diff --git a/qlib/__init__.py b/qlib/__init__.py index 154d4ea08..8620acdb7 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -82,12 +82,11 @@ def init(default_conf="client", **kwargs): LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") # set up QlibRecorder - default_uri = str(Path(os.getcwd()).resolve() / "mlruns") - current_uri = C["exp_uri"] if C["exp_uri"] is not None else default_uri + uri = C["exp_uri"] # exp manager module - module = get_module_by_module_path("qlib.workflow") + module = get_module_by_module_path("qlib.workflow.expm") exp_manager = init_instance_by_config(C["exp_manager"], module) - qr = QlibRecorder(exp_manager, default_uri, current_uri) + qr = QlibRecorder(exp_manager, uri) R.register(qr) diff --git a/qlib/config.py b/qlib/config.py index 0e2a264af..2bd77feb8 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -126,7 +126,7 @@ _default_config = { }, # Defatult config for experiment manager "exp_manager": {"class": "MLflowExpManager", "kwargs": {}}, - "exp_uri": None, + "exp_uri": str(Path(os.getcwd()).resolve() / "mlruns"), } MODE_CONF = { diff --git a/qlib/data/data.py b/qlib/data/data.py index 476cc9682..8eae9f01c 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -24,7 +24,7 @@ from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, get_provider_obj, register_wrapper +from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path class CalendarProvider(abc.ABC): @@ -1031,34 +1031,44 @@ D = Wrapper() def register_all_wrappers(): """register_all_wrappers""" logger = get_module_logger("data") - - _calendar_provider = get_provider_obj(C.calendar_provider) + module = get_module_by_module_path("qlib.data") + + _calendar_provider = init_instance_by_config(C.calendar_provider, module) if getattr(C, "calendar_cache", None) is not None: - _calendar_provider = get_provider_obj(C.calendar_cache, provider=_calendar_provider) - register_wrapper(Cal, _calendar_provider) + _calendar_cache_config = {} + _calendar_cache_config.update(C.calendar_cache) + _calendar_cache_config['kwargs'].update(provider=_calendar_provider) + _calendar_provider = init_instance_by_config(_calendar_cache_config, module) + register_wrapper(Cal, _calendar_provider, "qlib.data") logger.debug(f"registering Cal {C.calendar_provider}-{C.calenar_cache}") - register_wrapper(Inst, C.instrument_provider) + register_wrapper(Inst, C.instrument_provider, "qlib.data") logger.debug(f"registering Inst {C.instrument_provider}") if getattr(C, "feature_provider", None) is not None: - feature_provider = get_provider_obj(C.feature_provider) - register_wrapper(FeatureD, feature_provider) + feature_provider = init_instance_by_config(C.feature_provider, module) + register_wrapper(FeatureD, feature_provider, "qlib.data") logger.debug(f"registering FeatureD {C.feature_provider}") if getattr(C, "expression_provider", None) is not None: # This provider is unnecessary in client provider - _eprovider = get_provider_obj(C.expression_provider) + _eprovider = init_instance_by_config(C.expression_provider, module) if getattr(C, "expression_cache", None) is not None: - _eprovider = get_provider_obj(C.expression_cache, provider=_eprovider) - register_wrapper(ExpressionD, _eprovider) + _expression_cache_config = {} + _expression_cache_config.update(C.expression_cache) + _expression_cache_config['kwargs'].update(provider=_eprovider) + _eprovider = init_instance_by_config(C.expression_cache, module) + register_wrapper(ExpressionD, _eprovider, "qlib.data") logger.debug(f"registering ExpressioneD {C.expression_provider}-{C.expression_cache}") - _dprovider = get_provider_obj(C.dataset_provider) + _dprovider = init_instance_by_config(C.dataset_provider, module) if getattr(C, "dataset_cache", None) is not None: - _dprovider = get_provider_obj(C.dataset_cache, provider=_dprovider) - register_wrapper(DatasetD, _dprovider) + _dataset_cache_config = {} + _dataset_cache_config.update(C.dataset_cache) + _dataset_cache_config['kwargs'].update(provider=_dprovider) + _dprovider = init_instance_by_config(_dataset_cache_config, module) + register_wrapper(DatasetD, _dprovider, "qlib.data") logger.debug(f"registering DataseteD {C.dataset_provider}-{C.dataset_cache}") - register_wrapper(D, C.provider) + register_wrapper(D, C.provider, "qlib.data") logger.debug(f"registering D {C.provider}") diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 87b43f456..ca0ff4c28 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -632,21 +632,14 @@ class Wrapper(object): return getattr(self._provider, key) -def get_provider_obj(config, **params): - module = get_module_by_module_path("qlib.data") - klass, kwargs = get_cls_kwargs(config, module) - kwargs.update(params) - return klass(**kwargs) - - -def register_wrapper(wrapper, cls_or_obj): +def register_wrapper(wrapper, cls_or_obj, module_path=None): """register_wrapper - :param wrapper: A wrapper of all kinds of providers - :param cls_or_obj: A class or class name or object instance in data/data.py + :param wrapper: A wrapper. + :param cls_or_obj: A class or class name or object instance. """ if isinstance(cls_or_obj, str): - module = get_module_by_module_path("qlib.data") + module = get_module_by_module_path(module_path) cls_or_obj = getattr(module, cls_or_obj) obj = cls_or_obj() if isinstance(cls_or_obj, type) else cls_or_obj wrapper.register(obj) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 7c9c1928f..31b9ae2d7 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -2,24 +2,29 @@ # Licensed under the MIT License. from contextlib import contextmanager -from .expm import * +from .expm import MLflowExpManager from ..utils import Wrapper class QlibRecorder: - def __init__(self, exp_manager, default_uri, current_uri): + """ + A global system that helps to manage the experiments. + """ + def __init__(self, exp_manager, uri): self.exp_manager = exp_manager - self.default_uri = default_uri - self.current_uri = current_uri + self.uri = uri @contextmanager def start(self, experiment_name): - run = self.start_exp(experiment_name, self.current_uri) - yield run + run = self.start_exp(experiment_name, self.uri) + try: + yield run + except: + self.end_exp() # end the experiment if something went wrong self.end_exp() def start_exp(self, experiment_name=None): - return self.exp_manager.start_exp(experiment_name, self.current_uri) + return self.exp_manager.start_exp(experiment_name, self.uri) def end_exp(self): self.exp_manager.end_exp() @@ -33,8 +38,8 @@ class QlibRecorder: def delete_exp(self, experiment_id): self.exp_manager.delete_exp(experiment_id) - def get_uri(self, type): - return self.exp_manager.get_uri(type) + def get_uri(self): + return self.exp_manager.get_uri() def get_recorder(self): return self.exp_manager.active_recorder diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index a63187e28..335dd338b 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -3,7 +3,7 @@ import mlflow from pathlib import Path - +from .recorder import MLflowRecorder class Experiment: """ @@ -15,6 +15,19 @@ class Experiment: self.id = None self.recorders = list() + def create_recorder(self): + """ + Create a recorder for each experiment. + + Parameters + ---------- + + Returns + ------- + A recorder instance. + """ + raise NotImplementedError(f"Please implement the `create_recorder` method.") + def search_records(self, **kwargs): """ Get a pandas DataFrame of records that fit the search criteria of the experiment. @@ -36,15 +49,39 @@ class Experiment: """ raise NotImplementedError(f"Please implement the `search_records` method.") + def delete_recorder(self, rid): + """ + Create a recorder for each experiment. + + Parameters + ---------- + rid : str + the id of the recorder to be deleted. + + Returns + ------- + A recorder instance. + """ + raise NotImplementedError(f"Please implement the `delete_recorder` method.") + class MLflowExperiment(Experiment): """ Use mlflow to implement Experiment. """ + def create_recorder(self): + recorder = MLflowRecorder(self.id) + self.recorders.append(recorder) + return recorders + def search_records(self, **kwargs): filter_string = "" if kwargs.get("filter_string") is None else kwargs.get("filter_string") run_view_type = 1 if kwargs.get("run_view_type") is None else kwargs.get("run_view_type") max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results") order_by = kwargs.get("order_by") return mlflow.search_runs([self.experiment_id], filter_string, run_view_type, max_results, order_by) + + def delete_recorder(self, rid): + mlflow.delete_run(rid) + self.recorders = [r for r in self.recorders if r.recorder_id == rid] \ No newline at end of file diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 00d25da48..3c633e3bb 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -6,8 +6,10 @@ import os from pathlib import Path from contextlib import contextmanager from .exp import MLflowExperiment -from .record import MLflowRecorder +from .recorder import MLflowRecorder +from ..log import get_module_logger +logger = get_module_logger('workflow', 'Warning') class ExpManager: """ @@ -16,7 +18,7 @@ class ExpManager: """ def __init__(self): - self.default_uri = None + self.uri = None self.active_recorder = None # only one recorder can running each time self.experiments = dict() # store the experiment name --> Experiment object @@ -117,20 +119,18 @@ class ExpManager: """ raise NotImplementedError(f"Please implement the `create_exp` method.") - def get_uri(self, type): + def get_uri(self): """ Get the default tracking URI or current URI. Parameters ---------- - type : str - the type of the tracking URI one wants to retrieve. Returns ------- The tracking URI string. """ - raise NotImplementedError(f"Please implement the `create_exp` method.") + return self.uri def get_recorder(self): """ @@ -143,7 +143,7 @@ class ExpManager: ------- An Recorder object. """ - raise NotImplementedError(f"Please implement the `get_recorder` method.") + return self.active_recorder class MLflowExpManager(ExpManager): @@ -153,17 +153,14 @@ class MLflowExpManager(ExpManager): def __init__(self): super(MLflowExpManager, self).__init__() - self.default_uri = None - self.current_uri = None + self.uri = None def start_exp(self, experiment_name=None, uri=None): # create experiment experiment = self.__create_exp(experiment_name, uri) # set up recorder - recorder = MLflowRecorder(experiment.id) + recorder = experiment.create_recorder() self.active_recorder = recorder - # store the recorder - experiment.recorders.append(self.active_recorder) # store the experiment self.experiments[experiment_name] = experiment @@ -178,15 +175,15 @@ class MLflowExpManager(ExpManager): experiment = MLflowExperiment() # set the tracking uri if uri is None: - print( + logger.warning( "No tracking URI is provided. The default tracking URI is set as `mlruns` under the working directory." ) else: - self.current_uri = uri - mlflow.set_tracking_uri(self.current_uri) + self.uri = uri + mlflow.set_tracking_uri(self.uri) # start the experiment if experiment_name is None: - print("No experiment name provided. The default experiment name is set as `experiment`.") + logger.warning("No experiment name provided. The default experiment name is set as `experiment`.") experiment_id = mlflow.create_experiment("experiment") # set the active experiment mlflow.set_experiment("experiment") @@ -227,19 +224,8 @@ class MLflowExpManager(ExpManager): if self.experiments[name].id == experiment_id: return self.experiments[name] else: - print("No valid experiment is found. Please make sure the id and name are correctly given.") + raise Exception("No valid experiment is found. Please make sure the id and name are correctly given.") def delete_exp(self, experiment_id): mlflow.delete_experiment(experiment_id) self.experiments = {key: val for key, val in self.experiments.items() if val.id != experiment_id} - - def get_uri(self, type): - if uri == "default": - return self.default_uri - elif uri == "current": - return self.current_uri - else: - raise ValueError("Input type is not supported. Please choose type default or current to get the uri.") - - def get_recorder(self): - return self.active_recorder diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py new file mode 100644 index 000000000..62ee14405 --- /dev/null +++ b/qlib/workflow/record_temp.py @@ -0,0 +1,137 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd +from pathlib import Path +from ..contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) +from ..utils import init_instance_by_config, get_module_by_module_path + + +class RecordTemp: + def __init__(self, *args, **kwargs): + pass + + def generate(self, **kwargs): + """ + Generate certain records such as IC, backtest etc., and save them. + + Parameters + ---------- + kwargs + + Return + ------ + The generated records. + """ + raise NotImplementedError(f"Please implement the `generate` method.") + + def check(self, **kwargs): + """ + Check if the records is properly generated and saved. + + Parameters + ---------- + kwargs + """ + raise NotImplementedError(f"Please implement the `check` method.") + + +# TODO: this can only be run under R's running experiment. +class SignalRecord(RecordTemp): + def __init__(self, model, dataset, recorder, **kwargs): + super(SignalRecord, self).__init__() + self.model = model + self.dataset = dataset + self.recorder = recorder + + def generate(self, **kwargs): + # generate prediciton + pred = self.model.predict(self.dataset) + self.recorder.save_object(pred, 'pred.pkl') + + def load(self): + # try to load the saved object + try: + pred = self.recorder.load_object('pred.pkl') + return pred + except: + raise Exception('Something went wrong when loading the saved object.') + + def check(self, **kwargs): + return self.recorder.check('pred.pkl') + + +# TODO +class SigAnaRecord(SignalRecord): + def __init__(self, recorder, **kwargs): + + def generate(self): + pass + + def load(self): + pass + + def check(self): + pass + + +class PortAnaRecord(SignalRecord): + def __init__(self, recorder, STRATEGY_CONFIG, BACKTEST_CONFIG, **kwargs): + self.recorder = recorder + self.STRATEGY_CONFIG = STRATEGY_CONFIG + self.BACKTEST_CONFIG = BACKTEST_CONFIG + module = get_module_by_module_path("qlib.contrib.strategy") + self.strategy = init_instance_by_config(STRATEGY_CONFIG, module) + self.artifact_path = Path('portfolio_analysis').resolve() + + def generate(self, **kwargs): + """ + STRATEGY_CONFIG : dict + define the strategy class as well as the kwargs. + BACKTEST_CONFIG : dict + define the backtest kwargs. + """ + # check previously stored prediction results + assert super().check(), "Make sure the parent process is completed and store the data properly." + # custom strategy and get backtest + pred_score = super().load() + report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG) + self.recorder.save_object(report_normal, 'report_normal.pkl', self.artifact_path) + self.recorder.save_object(positions_normal, 'positions_normal.pkl', self.artifact_path) + + # analysis + analysis = dict() + analysis["excess_return_without_cost"] = risk_analysis(report_normal["return"] - report_normal["bench"]) + analysis["excess_return_with_cost"] = risk_analysis( + report_normal["return"] - report_normal["bench"] - report_normal["cost"] + ) + analysis_df = pd.concat(analysis) # type: pd.DataFrame + self.recorder.save_object(pred, 'port_analysis.pkl', self.artifact_path) + + def load(self): + # try to load the saved object + try: + pred = self.recorder.load_object(self.artifact_path / 'port_analysis.pkl'') + return pred + except: + raise Exception('Something went wrong when loading the saved object.') + + def check(self): + return self.recorder.check('port_analysis.pkl', self.artifact_path) + + + + + + + + + + + + + + diff --git a/qlib/workflow/record.py b/qlib/workflow/recorder.py similarity index 79% rename from qlib/workflow/record.py rename to qlib/workflow/recorder.py index e132710ca..042b052e0 100644 --- a/qlib/workflow/record.py +++ b/qlib/workflow/recorder.py @@ -21,7 +21,7 @@ class Recorder: def set_recorder_name(self, rname): self.recorder_name = rname - def save_object(self, data, name, local_path=None): + def save_object(self, data=None, name=None, local_path=None, artifact_path=None): """ Save object such as prediction file or model checkpoints to the artifact URI. @@ -33,10 +33,12 @@ class Recorder: name of the file to be saved. local_path : str if provided, them save the file or directory to the artifact URI. + artifact_path=None : str + the relative path for the artifact to be stored in the URI. """ raise NotImplementedError(f"Please implement the `save_object` method.") - def save_objects(self, data_name_list, local_path=None): + def save_objects(self, data_name_list=None, local_path=None, artifact_path=None): """ Save objects such as prediction file or model checkpoints to the artifact URI. @@ -46,6 +48,8 @@ class Recorder: list of (data, name) pairs local_path : str if provided, them save the file or directory to the artifact URI. + artifact_path=None : str + the relative path for the artifact to be stored in the URI. """ raise NotImplementedError(f"Please implement the `save_objects` method.") @@ -162,6 +166,7 @@ class MLflowRecorder(Recorder): # save the run id and artifact_uri self.recorder_id = run.info.run_id self.artifact_uri = run.info.artifact_uri + self._uri = mlflow.get_tracking_uri() # Fix!!! : this is not proper to have uri in recorder # set up file manager for saving objects self.temp_dir = tempfile.mkdtemp() self.fm = FileManager(Path(self.temp_dir).absolute()) @@ -171,27 +176,27 @@ class MLflowRecorder(Recorder): mlflow.end_run() shutil.rmtree(self.temp_dir) - def save_object(self, data, name, local_path=None): + def save_object(self, data=None, name=None, local_path=None, artifact_path=None): + client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) if local_path is None: assert data is not None and name is not None, "Please provide data and name input." self.fm.save_obj(data, name) - mlflow.log_artifact(self.fm.path / name) - self.fm.remove(name) + client.log_artifact(self.recorder_id, self.fm.path / name, artifact_path) else: - mlflow.log_artifact(local_path) + assert local_path is not None, "Please provide a valid local path for the " + client.log_artifact(self.recorder_id, local_path, artifact_path) - def save_objects(self, data_name_list, local_path=None): + def save_objects(self, data_name_list=None, local_path=None, artifact_path=None): + client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) if local_path is None: assert data_name_list is not None, "Please provide data_name_list input." self.fm.save_objs(data_name_list) - mlflow.log_artifacts(self.fm.path) - for obj, name in data_name_list: - self.fm.remove(name) + client.log_artifacts(self.recorder_id, self.fm.path, artifact_path) else: - mlflow.log_artifacts(local_path) + client.log_artifacts(self.recorder_id, local_path, artifact_path) def load_object(self, name): - client = mlflow.tracking.MlflowClient() + client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) path = client.download_artifacts(self.recorder_id, name) try: with Path(path).open("rb") as f: @@ -229,3 +234,11 @@ class MLflowRecorder(Recorder): if self.artifact_uri is not None: return self.artifact_uri return mlflow.get_artifact_uri(artifact_path) + + def check(self, name, path=None): + client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) + artifacts = client.list_artifacts(self.recorder_id, path) + for artifact in artifacts + if name in artifact.path: + return True + return False \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 165619920..f927ce5a2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -22,3 +22,4 @@ scikit_learn==0.23.2 torch==1.6.0 tqdm==4.49.0 yahooquery==2.2.7 +mlflow==1.11.0 \ No newline at end of file diff --git a/setup.py b/setup.py index 3a6237e5a..47ddceaf8 100644 --- a/setup.py +++ b/setup.py @@ -50,6 +50,7 @@ REQUIRED = [ "matplotlib==3.1.3", "tables>=3.6.1", "pyyaml>=5.3.1", + "mlflow>=1.10.0", "tqdm", "loguru", "lightgbm",