diff --git a/qlib/__init__.py b/qlib/__init__.py index 8d0b322b1..f2b2c28ac 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -13,7 +13,7 @@ import platform import yaml from pathlib import Path -from .utils import can_use_cache +from .utils import can_use_cache, init_instance_by_config, get_module_by_module_path # init qlib @@ -22,6 +22,7 @@ def init(default_conf="client", **kwargs): from .data.data import register_all_wrappers from .log import get_module_logger, set_log_with_config from .data.cache import H + from .workflow import R, QlibRecorder C.reset() H.clear() @@ -79,6 +80,15 @@ def init(default_conf="client", **kwargs): if "flask_server" in C: LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") + + # set up QlibRecorder + default_uri = str(Path(os.getcwd()).resolve() / "mlruns") + current_uri = C['exp_uri'] if C['exp_uri'] is not None else default_uri + # exp manager module + module = get_module_by_module_path('qlib.workflow') + exp_manager = init_instance_by_config(C['exp_manager'], module) + qr = QlibRecorder(exp_manager, default_uri, current_uri) + R.register(qr) def _mount_nfs_uri(C): diff --git a/qlib/config.py b/qlib/config.py index ff01fe5e8..db5fab69c 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -124,6 +124,12 @@ _default_config = { }, "loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}}, }, + # Default config for experiment manager + "exp_manager": { + "class": "MLflowExpManager", + "kwargs": {} + }, + "exp_uri": None, } MODE_CONF = { diff --git a/qlib/data/data.py b/qlib/data/data.py index c41d32f6e..476cc9682 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -24,6 +24,7 @@ from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache +from ..utils import Wrapper, get_provider_obj, register_wrapper class CalendarProvider(abc.ABC): @@ -1019,44 +1020,6 @@ class 
ClientProvider(BaseProvider): DatasetD.set_conn(self.client) -class Wrapper(object): - """Data Provider Wrapper""" - - def __init__(self): - self._provider = None - - def register(self, provider): - self._provider = provider - - def __getattr__(self, key): - if self._provider is None: - raise AttributeError("Please run qlib.init() first using qlib") - return getattr(self._provider, key) - - -def get_cls_from_name(cls_name): - return getattr(importlib.import_module(".data", package="qlib"), cls_name) - - -def get_provider_obj(config, **params): - if isinstance(config, dict): - params.update(config["kwargs"]) - config = config["class"] - return get_cls_from_name(config)(**params) - - -def register_wrapper(wrapper, cls_or_obj): - """register_wrapper - - :param wrapper: A wrapper of all kinds of providers - :param cls_or_obj: A class or class name or object instance in data/data.py - """ - if isinstance(cls_or_obj, str): - cls_or_obj = get_cls_from_name(cls_or_obj) - obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj - wrapper.register(obj) - - Cal = Wrapper() Inst = Wrapper() FeatureD = Wrapper() diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index b10735868..8467db600 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -611,3 +611,39 @@ def exists_qlib_data(qlib_dir): return False return True + + +#################### Wrapper ##################### +class Wrapper(object): + """Data Provider Wrapper""" + + def __init__(self): + self._provider = None + + def register(self, provider): + self._provider = provider + + def __getattr__(self, key): + if self._provider is None: + raise AttributeError("Please run qlib.init() first using qlib") + return getattr(self._provider, key) + + +def get_provider_obj(config, **params): + module = get_module_by_module_path("qlib.data") + klass, kwargs = get_cls_kwargs(config, module) + kwargs.update(params) + return klass(**kwargs) + + +def register_wrapper(wrapper, cls_or_obj): + 
"""register_wrapper + + :param wrapper: A wrapper of all kinds of providers + :param cls_or_obj: A class or class name or object instance in data/data.py + """ + if isinstance(cls_or_obj, str): + module = get_module_by_module_path("qlib.data") + cls_or_obj = getattr(module, cls_or_obj) + obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj + wrapper.register(obj) \ No newline at end of file diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 76d5e7d4c..db5112470 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -2,156 +2,62 @@ # Licensed under the MIT License. from contextlib import contextmanager -from .record import MLflowRecorder -from .exp import MLflowExpManager +from .expm import * +from ..utils import Wrapper -class Record: - def __init__(self): - pass +class QlibRecorder: + def __init__(self, exp_manager, default_uri, current_uri): + self.exp_manager = exp_manager + self.default_uri = default_uri + self.current_uri = current_uri @contextmanager - def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False): - raise NotImplementedError(f"Please implement the `start_exp` method.") + def start(self, experiment_name): + run = self.start_exp(experiment_name) + yield run + self.end_exp() - def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None): - raise NotImplementedError(f"Please implement the `search_runs` method.") + def start_exp(self, experiment_name=None): + return self.exp_manager.start_exp(experiment_name, self.current_uri) + + def end_exp(self): + self.exp_manager.end_exp() - def get_exp(self, experiment_id): - raise NotImplementedError(f"Please implement the `get_exp` method.") - - def get_exp_by_name(self, experiment_name): - raise NotImplementedError(f"Please implement the `get_exp_by_name` method.") + def search_records(self, experiment_ids, **kwargs): + return 
self.exp_manager.search_records(experiment_ids, **kwargs) - def create_exp(self, experiment_name, artifact_location=None): - raise NotImplementedError(f"Please implement the `create_exp` method.") - - def set_exp(self, experiment_name): - raise NotImplementedError(f"Please implement the `set_exp` method.") - - def delete_exp(self, experiment_id): - raise NotImplementedError(f"Please implement the `create_exp` method.") - - def set_tracking_uri(self, uri): - raise NotImplementedError(f"Please implement the `set_tracking_uri` method.") - - def get_tracking_uri(self): - raise NotImplementedError(f"Please implement the `get_tracking_uri` method.") - - def get_recorder(self): - raise NotImplementedError(f"Please implement the `get_recorder` method.") - - def save_object(self, name, data): - raise NotImplementedError(f"Please implement the `save_object` method.") - - def save_objects(self, name_data_list): - raise NotImplementedError(f"Please implement the `save_objects` method.") - - def load_object(self, name): - raise NotImplementedError(f"Please implement the `load_object` method.") - - def log_param(self, key, value): - raise NotImplementedError(f"Please implement the `log_param` method.") - - def log_params(self, params): - raise NotImplementedError(f"Please implement the `log_params` method.") - - def log_metric(self, key, value, step=None): - raise NotImplementedError(f"Please implement the `log_metric` method.") - - def log_metrics(self, metrics, step=None): - raise NotImplementedError(f"Please implement the `log_metrics` method.") - - def set_tag(self, key, value): - raise NotImplementedError(f"Please implement the `set_tag` method.") - - def set_tags(self, tags): - raise NotImplementedError(f"Please implement the `log_tags` method.") - - def delete_tag(self, key): - raise NotImplementedError(f"Please implement the `delete_tag` method.") - - def log_artifact(self, local_path, artifact_path=None): - raise NotImplementedError(f"Please implement the `log_artifact` 
method.") - - def log_artifacts(self, local_dir, artifact_path=None): - raise NotImplementedError(f"Please implement the `log_artifacts` method.") - - def get_artifact_uri(self, artifact_path=None): - raise NotImplementedError(f"Please implement the `get_artifact_uri` method.") - -class MLflowRecord(Record): - def __init__(self): - self.exp_manager = MLflowExpManager() - - @contextmanager - def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False): - yield self.exp_manager.start_exp(experiment_name, uri, project_path, artifact_location, nested) - - def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None): - return self.exp_manager.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) - - def get_exp(self, experiment_id): - return self.exp_manager.get_exp(experiment_id) - - def get_exp_by_name(self, experiment_name): - return self.exp_manager.get_exp_by_name(experiment_name) - - def create_exp(self, experiment_name, artifact_location=None): - self.exp_manager.create_exp(experiment_name, artifact_location) - - def set_exp(self, experiment_name): - self.exp_manager.set_exp(experiment_name) + def get_exp(self, experiment_id=None, experiment_name=None): + return self.exp_manager.get_exp(experiment_id, experiment_name) def delete_exp(self, experiment_id): self.exp_manager.delete_exp(experiment_id) - def set_tracking_uri(self, uri): - self.exp_manager.set_tracking_uri(uri) - - def get_tracking_uri(self): - return self.exp_manager.get_tracking_uri() - + def get_uri(self, type): + return self.exp_manager.get_uri(type) + def get_recorder(self): - return self.exp_manager.get_recorder() + return self.exp_manager.active_recorder - def save_object(self, name, data): - self.exp_manager.active_recorder.save_object(name, data) + def save_object(self, data=None, name=None, local_path=None): + self.exp_manager.active_recorder.save_object(data, name, 
local_path) - def save_objects(self, name_data_list): - self.exp_manager.active_recorder.save_objects(name_data_list) + def save_objects(self, data_name_list=None, local_path=None): + self.exp_manager.active_recorder.save_objects(data_name_list, local_path) def load_object(self, name): return self.exp_manager.active_recorder.load_object(name) + + def log_params(self, **kwargs): + self.exp_manager.active_recorder.log_params(**kwargs) + + def log_metrics(self, step=None, **kwargs): + self.exp_manager.active_recorder.log_metrics(step, **kwargs) - def log_param(self, key, value): - self.exp_manager.active_recorder.log_param(key, value) - - def log_params(self, params): - self.exp_manager.active_recorder.log_params(params) - - def log_metric(self, key, value, step=None): - self.exp_manager.active_recorder.log_metric(key, value, step) - - def log_metrics(self, metrics, step=None): - self.exp_manager.active_recorder.log_metrics(metrics, step) - - def set_tag(self, key, value): - self.exp_manager.active_recorder.set_tag(key, value) - - def set_tags(self, tags): - self.exp_manager.active_recorder.set_tags(tags) + def set_tags(self, **kwargs): + self.exp_manager.active_recorder.set_tags(**kwargs) def delete_tag(self, key): self.exp_manager.active_recorder.delete_tag(key) - - def log_artifact(self, local_path, artifact_path=None): - self.exp_manager.active_recorder.log_artifact(local_path, artifact_path) - - def log_artifacts(self, local_dir, artifact_path=None): - self.exp_manager.active_recorder.log_artifacts(local_dir, artifact_path) - - def get_artifact_uri(self, artifact_path=None): - return self.exp_manager.active_recorder.get_artifact_uri(artifact_path) # global record -R = MLflowRecord() \ No newline at end of file +R = Wrapper() \ No newline at end of file diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index f3cedea90..9e076aced 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -2,67 +2,23 @@ # Licensed under the MIT License. 
import mlflow -from contextlib import contextmanager -from .record import MLflowRecorder +from pathlib import Path -class ExpManager: +class Experiment: """ - This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow. - (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) + This is the `Experiment` class for each experiment being run. The API is designed similar to mlflow. """ def __init__(self): - self.active_recorder = None - self.experiments = dict() # store the experiment names -> list of recorders. - self.exp_ids = list() - - def _store_exp(self, id, name): - """ - Store the experiments in the experiments holder. - """ - if id in self.exp_ids: - raise Exception('Something went wrong when creating the experiment. Please check if the experiment is already created.') - if name in self.experiments: - assert int(id) == int(self.experiments[name][0]), 'Experiment id and name are not consistent when storing the experiment.' - else: - self.exp_ids.append(id) - self.experiments[name] = [id] + self.name = None + self.id = None + self.recorders = list() - def start_exp(self, project_path, experiment_name=None, uri=None, artifact_location=None, nested=False): + def search_records(self, **kwargs): """ - Start running an experiment. This method can only work in the `with` statement. + Get a pandas DataFrame of records that fit the search criteria of the experiment. Parameters ---------- - project_path : str - path for the project. - experiment_name : str - name of the active experiment. - uri : str - the current tracking URI. - artifact_location : str - the location to store all the artifacts. - nested : boolean - controls whether run is nested in parent run. - - Returns - None - """ - raise NotImplementedError(f"Please implement the `start_exp` method.") - - def end_exp(self): - """ - End an active experiment. 
- """ - raise NotImplementedError(f"Please implement the `end_exp` method.") - - def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None): - """ - Get a pandas DataFrame of runs that fit the search criteria. - - Parameters - ---------- - experiment_ids : list - list of experiment IDs. filter_string : str filter query string, defaults to searching all runs. run_view_type : int @@ -74,192 +30,18 @@ class ExpManager: Returns ------- - A pandas.DataFrame of runs. + A pandas.DataFrame of records. """ - raise NotImplementedError(f"Please implement the `search_runs` method.") - - def get_exp(self, experiment_id): - """ - Retrieve an experiment by experiment_id from the backend store. - - Parameters - ---------- - experiment_id : str - the experiment id to return. - - Returns - ------- - An experiment object (e.g. mlflow.entities.Experiment). - """ - raise NotImplementedError(f"Please implement the `get_exp` method.") - - def get_exp_by_name(self, experiment_name): - """ - Retrieve an experiment by experiment name from the backend store. - - Parameters - ---------- - experiment_name : str - the experiment name to return. - - Returns - ------- - An experiment object (e.g. mlflow.entities.Experiment). - """ - raise NotImplementedError(f"Please implement the `get_exp_by_name` method.") - - def create_exp(self, experiment_name, artifact_location=None): - """ - Create an experiment. - - Parameters - ---------- - experiment_name : str - the experiment name, which must be unique. - artifact_location : str - the location to store run artifacts. - - Returns - ------- - String id of created experiment. - """ - raise NotImplementedError(f"Please implement the `create_exp` method.") - - def set_exp(self, experiment_name): - """ - Set the experiment to be active. - - Parameters - ---------- - experiment_name : str - the experiment name, which must be unique. - - Returns - ------- - String id of created experiment. 
- """ - raise NotImplementedError(f"Please implement the `set_exp` method.") - - def delete_exp(self, experiment_id): - """ - Delete an experiment. - - Parameters - ---------- - experiment_id : str - the experiment id. - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `create_exp` method.") - - def set_tracking_uri(self, uri): - """ - Set the tracking server URI. - - Parameters - ---------- - uri : str - the uri of the tracking server, can be An empty string, or a local file path, prefixed with file:/. - or An HTTP URI or A Databricks workspace. - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `set_tracking_uri` method.") - - def get_tracking_uri(self): - """ - Get the tracking server URI. - - Parameters - ---------- - - Returns - ------- - The tracking URI. - """ - raise NotImplementedError(f"Please implement the `get_tracking_uri` method.") - - def get_recorder(self): - """ - Get the current active Recorder. - - Parameters - ---------- - - Returns - ------- - An Recorder object. - """ - raise NotImplementedError(f"Please implement the `get_recorder` method.") + raise NotImplementedError(f"Please implement the `search_records` method.") -class MLflowExpManager(ExpManager): - ''' - Use mlflow to implement ExpManager. - ''' - def start_exp(self, experiment_name=None, uri=None, project_path=None, artifact_location=None, nested=False): - # set the tracking uri - if uri is None: - assert project_path is not None, "Please provide the project_path if no uri is provided in order to set a proper tracking uri." - print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the project path.') - mlflow.set_tracking_uri(str(project_path / "mlruns")) - else: - mlflow.set_tracking_uri(uri) - # start the experiment - if experiment_name is None: - print('No experiment name provided. 
The default experiment name is set as `experiment`.') - experiment_id = self.create_exp('experiment', artifact_location) - # set the active experiment - self.set_exp('experiment') - experiment_name = 'experiment' - else: - if experiment_name not in self.experiments: - if self.get_exp_by_name(experiment_name) is not None: - raise Exception('The experiment has already been created before. Please pick another name or delete the files under tracking uri.') - experiment_id = self.create_exp(experiment_name, artifact_location) - else: - experiment_id = self.experiments(experiment_name)[0] - # set the active experiment - self.set_exp(experiment_name) - - # store the id and name - self._store_exp(experiment_id, experiment_name) - # set up recorder - recorder = MLflowRecorder(experiment_id) - self.active_recorder = recorder - # store the recorder - self.experiments[experiment_name].append(self.active_recorder) - - return self.active_recorder.start_run(experiment_id=experiment_id, nested=nested) - - def search_runs(self, experiment_ids=None, filter_string='', run_view_type=1, max_results=100000, order_by=None): - return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) - - def get_exp(self, experiment_id): - return mlflow.get_experiment(experiment_id) - - def get_exp_by_name(self, experiment_name): - return mlflow.get_experiment_by_name(experiment_name) - - def create_exp(self, experiment_name, artifact_location=None): - return mlflow.create_experiment(experiment_name, artifact_location) - - def set_exp(self, experiment_name): - mlflow.set_experiment(experiment_name) - - def delete_exp(self, experiment_id): - mlflow.delete_experiment(experiment_id) - self.experiments = {key:val for key, val in self.experiments.items() if val[0] != experiment_id} - - def set_tracking_uri(self, uri): - mlflow.set_tracking_uri(uri) - - def get_tracking_uri(self): - return mlflow.get_tracking_uri() - - def get_recorder(self): - return self.active_recorder \ 
No newline at end of file +class MLflowExperiment(Experiment): + """ + Use mlflow to implement Experiment. + """ + def search_records(self, **kwargs): + filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string') + run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type') + max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results') + order_by = kwargs.get('order_by') + return mlflow.search_runs([self.id], filter_string, run_view_type, max_results, order_by) \ No newline at end of file diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py new file mode 100644 index 000000000..36a945f42 --- /dev/null +++ b/qlib/workflow/expm.py @@ -0,0 +1,236 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import mlflow +import os +from pathlib import Path +from contextlib import contextmanager +from .exp import MLflowExperiment +from .record import MLflowRecorder + +class ExpManager: + """ + This is the `ExpManager` class for managing the experiments. The API is designed similar to mlflow. + (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) + """ + def __init__(self): + self.default_uri = None + self.active_recorder = None # only one recorder can be running at a time + self.experiments = dict() # store the experiment name --> Experiment object + + def start_exp(self, experiment_name=None, uri=None, **kwargs): + """ + Start running an experiment. + + Parameters + ---------- + experiment_name : str + name of the active experiment. + uri : str + the current tracking URI. + artifact_location : str + the location to store all the artifacts. + nested : boolean + controls whether run is nested in parent run. + + Returns + An object wrapped by context manager. + """ + raise NotImplementedError(f"Please implement the `start_exp` method.") + + def end_exp(self, **kwargs): + """ + End a running experiment. 
+ + Parameters + ---------- + experiment_name : str + name of the active experiment. + """ + raise NotImplementedError(f"Please implement the `end_exp` method.") + + def search_records(self, experiment_ids=None, **kwargs): + """ + Get a pandas DataFrame of records that fit the search criteria. + + Parameters + ---------- + experiment_ids : list + list of experiment IDs. + filter_string : str + filter query string, defaults to searching all runs. + run_view_type : int + one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL (e.g. in mlflow.entities.ViewType). + max_results : int + the maximum number of runs to put in the dataframe. + order_by : list + list of columns to order by (e.g., “metrics.rmse”). + + Returns + ------- + A pandas.DataFrame of runs. + """ + raise NotImplementedError(f"Please implement the `search_records` method.") + + def __create_exp(self, experiment_name, artifact_location=None): + """ + Create an experiment. + + Parameters + ---------- + experiment_name : str + the experiment name, which must be unique. + artifact_location : str + the location to store run artifacts. + + Returns + ------- + An experiment object. + """ + raise NotImplementedError(f"Please implement the `create_exp` method.") + + def get_exp(self, experiment_id=None, experiment_name=None): + """ + Retrieve an experiment by experiment_id from the backend store. + + Parameters + ---------- + experiment_id : str + the experiment id to return. + + Returns + ------- + An experiment object. + """ + raise NotImplementedError(f"Please implement the `get_exp` method.") + + def delete_exp(self, experiment_id): + """ + Delete an experiment. + + Parameters + ---------- + experiment_id : str + the experiment id. + """ + raise NotImplementedError(f"Please implement the `delete_exp` method.") + + def get_uri(self, type): + """ + Get the default tracking URI or current URI. + + Parameters + ---------- + type : str + the type of the tracking URI one wants to retrieve. 
+ + Returns + ------- + The tracking URI string. + """ + raise NotImplementedError(f"Please implement the `get_uri` method.") + + def get_recorder(self): + """ + Get the current active Recorder. + + Parameters + ---------- + + Returns + ------- + A Recorder object. + """ + raise NotImplementedError(f"Please implement the `get_recorder` method.") + + +class MLflowExpManager(ExpManager): + ''' + Use mlflow to implement ExpManager. + ''' + def __init__(self): + super(MLflowExpManager, self).__init__() + self.default_uri = None + self.current_uri = None + + def start_exp(self, experiment_name=None, uri=None): + # create experiment + experiment = self.__create_exp(experiment_name, uri) + # set up recorder + recorder = MLflowRecorder(experiment.id) + self.active_recorder = recorder + # store the recorder + experiment.recorders.append(self.active_recorder) + # store the experiment + self.experiments[experiment.name] = experiment + + return self.active_recorder.start_run(experiment_id=experiment.id) + + def end_exp(self): + self.active_recorder.end_run() + self.active_recorder = None + + def __create_exp(self, experiment_name=None, uri=None): + # init experiment + experiment = MLflowExperiment() + # set the tracking uri + if uri is None: + print('No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory.') + else: + self.current_uri = uri + mlflow.set_tracking_uri(self.current_uri) + # start the experiment + if experiment_name is None: + print('No experiment name provided. The default experiment name is set as `experiment`.') + experiment_id = mlflow.create_experiment('experiment') + # set the active experiment + mlflow.set_experiment('experiment') + experiment_name = 'experiment' + else: + if experiment_name not in self.experiments: + if mlflow.get_experiment_by_name(experiment_name) is not None: + raise Exception('The experiment has already been created before. 
Please pick another name or delete the files under uri.') + experiment_id = mlflow.create_experiment(experiment_name) + else: + experiment_id = self.experiments[experiment_name].id + experiment = self.experiments[experiment_name] + # set the active experiment + mlflow.set_experiment(experiment_name) + # set up experiment + experiment.id = experiment_id + experiment.name = experiment_name + + return experiment + + def search_records(self, experiment_ids, **kwargs): + filter_string = '' if kwargs.get('filter_string') is None else kwargs.get('filter_string') + run_view_type = 1 if kwargs.get('run_view_type') is None else kwargs.get('run_view_type') + max_results = 100000 if kwargs.get('max_results') is None else kwargs.get('max_results') + order_by = kwargs.get('order_by') + return mlflow.search_runs(experiment_ids, filter_string, run_view_type, max_results, order_by) + + def get_exp(self, experiment_id=None, experiment_name=None): + assert experiment_id is not None or experiment_name is not None, 'Please provide at least one of the experiment id or name to retrieve an experiment.' + if experiment_name is not None: + return self.experiments[experiment_name] + else: + for name in self.experiments: + if self.experiments[name].id == experiment_id: + return self.experiments[name] + else: + print('No valid experiment is found. Please make sure the id and name are correctly given.') + + def delete_exp(self, experiment_id): + mlflow.delete_experiment(experiment_id) + self.experiments = {key:val for key, val in self.experiments.items() if val.id != experiment_id} + + def get_uri(self, type): + if type == 'default': + return self.default_uri + elif type == 'current': + return self.current_uri + else: + raise ValueError('Input type is not supported. 
Please choose type default or current to get the uri.') + + def get_recorder(self): + return self.active_recorder \ No newline at end of file diff --git a/qlib/workflow/record.py b/qlib/workflow/record.py index 7895cf0fb..071c92691 100644 --- a/qlib/workflow/record.py +++ b/qlib/workflow/record.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. import mlflow -import shutil +import shutil, os, pickle, tempfile, codecs from pathlib import Path from ..utils.objm import FileManager @@ -12,45 +12,39 @@ class Recorder: (The link: https://mlflow.org/docs/latest/python_api/mlflow.html) """ - def __init__(self, experiment_id, project_path=None): + def __init__(self, experiment_id): self.experiment_id = experiment_id self.recorder_id = None self.recorder_name = None - self.fm = None - self.artifact_uri = None def set_recorder_name(self, rname): self.recorder_name = rname - def save_object(self, name, data): + def save_object(self, data, name, local_path=None): """ - Save object such as prediction file or model checkpoints. + Save object such as prediction file or model checkpoints to the artifact URI. Parameters ---------- - name : str - name of the file to be saved. data : any type the data to be saved. - - Returns - ------- - None. + name : str + name of the file to be saved. + local_path : str + if provided, then save the file or directory to the artifact URI. """ raise NotImplementedError(f"Please implement the `save_object` method.") - def save_objects(self, name_data_list): + def save_objects(self, data_name_list, local_path=None): """ - Save objects such as prediction file or model checkpoints. + Save objects such as prediction file or model checkpoints to the artifact URI. Parameters ---------- - name_data_list : list - list of (name, data) pairs - - Returns - ------- - None. + data_name_list : list + list of (data, name) pairs + local_path : str + if provided, then save the file or directory to the artifact URI. 
""" raise NotImplementedError(f"Please implement the `save_objects` method.") @@ -98,99 +92,36 @@ class Recorder: """ raise NotImplementedError(f"Please implement the `end_run` method.") - def log_param(self, key, value): - """ - Log a parameter under the current run. - - Parameters - ---------- - key : str - the name of the parameter - value : str - the value of the parameter - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `log_param` method.") - - def log_params(self, params): + def log_params(self, **kwargs): """ Log a batch of params for the current run. Parameters ---------- - params : dict - dictionary of param_name: String -> value: String. - - Returns - ------- - None + keyword arguments + key, value pair to be logged as parameters. """ raise NotImplementedError(f"Please implement the `log_params` method.") - def log_metric(self, key, value, step=None): - """ - Log a metric under the current run. - - Parameters - ---------- - key : str - the name of the metric - value : float - the value of the metric - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `log_metric` method.") - - def log_metrics(self, metrics, step=None): + def log_metrics(self, step=None, **kwargs): """ Log multiple metrics for the current run. Parameters ---------- - metrics : dict - dictionary of metric_name: String -> value: Float. - - Returns - ------- - None + keyword arguments + key, value pair to be logged as metrics. """ raise NotImplementedError(f"Please implement the `log_metrics` method.") - - def set_tag(self, key, value): - """ - Set a tag under the current run. - Parameters - ---------- - key : str - the name of the tag - value : str - the value of the tag - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `set_tag` method.") - - def set_tags(self, tags): + def set_tags(self, **kwargs): """ Log a batch of tags for the current run. 
Parameters ---------- - tags : dict - dictionary of tag_name: String -> value: String. - - Returns - ------- - None + keyword arguments + key, value pair to be logged as tags. """ raise NotImplementedError(f"Please implement the `log_tags` method.") @@ -202,67 +133,22 @@ class Recorder: ---------- key : str the name of the tag to be deleted. - - Returns - ------- - None """ raise NotImplementedError(f"Please implement the `delete_tag` method.") - - def log_artifact(self, local_path, artifact_path=None): - """ - Log a local file or directory as an artifact of the currently active run. - - Parameters - ---------- - local_path : str - path to the file to write. - artifact_path : str - the directory in `artifact_uri` to write to. - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `log_artifact` method.") - - def log_artifacts(self, local_dir, artifact_path=None): - """ - Log all the contents of a local directory as artifacts of the run. - - Parameters - ---------- - local_dir : str - path to the directory of files to write. - artifact_path : str - the directory in `artifact_uri` to write to. - - Returns - ------- - None - """ - raise NotImplementedError(f"Please implement the `log_artifacts` method.") - - def get_artifact_uri(self, artifact_path=None): - """ - Get the absolute URI of the specified artifact in the currently active run. - - Parameters - ---------- - artifact_path : str - the directory in `artifact_uri` to write to. - - Returns - ------- - An absolute URI referring to the specified artifact or currently active Recorder. - """ - raise NotImplementedError(f"Please implement the `get_artifact_uri` method.") class MLflowRecorder(Recorder): ''' Use mlflow to implement a Recorder. + + Due to the fact that mlflow will only log artifact from a file or directory, we decide to + use file manager to help maintain the objects in the project. 
def __init__(self, experiment_id):
    """
    Create a recorder bound to ``experiment_id``.

    The file manager (``fm``) and its temporary directory (``temp_dir``)
    start out as ``None``; ``start_run`` assigns both.
    """
    super(MLflowRecorder, self).__init__(experiment_id)
    self.fm = None
    self.temp_dir = None


def end_run(self):
    """
    End the active mlflow run and delete the temporary staging directory.
    """
    mlflow.end_run()
    # `temp_dir` is None until `start_run` has created it; guarding here
    # makes `end_run` safe without a started run and safe to call twice.
    if self.temp_dir is not None:
        shutil.rmtree(self.temp_dir)
        self.temp_dir = None


def save_object(self, data, name, local_path=None):
    """
    Log one object (or an existing local file) as an artifact of the run.

    Parameters
    ----------
    data : any
        object to be saved; used only when ``local_path`` is None.
    name : str
        file name under which ``data`` is staged and logged.
    local_path : str
        if given, this path is logged directly and ``data``/``name`` are
        ignored.
    """
    if local_path is not None:
        mlflow.log_artifact(local_path)
        return
    assert data is not None and name is not None, "Please provide data and name input."
    # mlflow logs artifacts from files only, so stage the object through the
    # file manager, upload it, then drop the staged copy.
    self.fm.save_obj(data, name)
    try:
        mlflow.log_artifact(self.fm.path / name)
    finally:
        # Remove the staged file even when the upload fails.
        self.fm.remove(name)


def save_objects(self, data_name_list, local_path=None):
    """
    Log several objects (or an existing local directory) as artifacts.

    Parameters
    ----------
    data_name_list : list
        (data, name) pairs to be staged and logged; used only when
        ``local_path`` is None.
    local_path : str
        if given, the contents of this directory are logged directly.
    """
    if local_path is not None:
        mlflow.log_artifacts(local_path)
        return
    assert data_name_list is not None, "Please provide data_name_list input."
    self.fm.save_objs(data_name_list)
    try:
        mlflow.log_artifacts(self.fm.path)
    finally:
        # Remove the staged files even when the upload fails.
        for _, name in data_name_list:
            self.fm.remove(name)


def load_object(self, name):
    """
    Download artifact ``name`` of this run and return its content.

    The file is unpickled when possible; otherwise it is returned as UTF-8
    text.

    Parameters
    ----------
    name : str
        artifact path, relative to the run's artifact root.
    """
    client = mlflow.tracking.MlflowClient()
    path = client.download_artifacts(self.recorder_id, name)
    # NOTE(security): unpickling executes code embedded in the file; only
    # load artifacts from tracking stores you trust.
    try:
        with Path(path).open('rb') as f:
            return pickle.load(f)
    except Exception:
        # Deliberate best-effort fallback for non-pickle artifacts.
        # `Exception` (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        with codecs.open(path, mode="r", encoding='utf-8') as f:
            return f.read()


def log_params(self, **kwargs):
    """
    Log the given keyword arguments as parameters of the current run.

    Calling it without keyword arguments is a no-op.
    """
    if not kwargs:
        return
    # The batch API covers the single-parameter case as well.
    mlflow.log_params(kwargs)


def log_metrics(self, step=None, **kwargs):
    """
    Log the given keyword arguments as metrics of the current run.

    Parameters
    ----------
    step : int
        forwarded to ``mlflow.log_metrics`` as the metric step.
    keyword arguments
        key, value pair to be logged as metrics.

    Calling it without keyword arguments is a no-op.
    """
    if not kwargs:
        return
    mlflow.log_metrics(kwargs, step=step)


def set_tags(self, **kwargs):
    """
    Set the given keyword arguments as tags on the current run.

    Calling it without keyword arguments is a no-op.
    """
    if not kwargs:
        return
    mlflow.set_tags(kwargs)


def delete_tag(self, key):
    """
    Delete the tag named ``key`` from the current run.
    """
    mlflow.delete_tag(key)