diff --git a/qlib/__init__.py b/qlib/__init__.py index 83c36dbd3..b26ac986d 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -83,11 +83,9 @@ def init(default_conf="client", **kwargs): LOG.info(f"flask_server={C['flask_server']}, flask_port={C['flask_port']}") # set up QlibRecorder - uri = C["exp_uri"] - # exp manager module module = get_module_by_module_path("qlib.workflow.expm") exp_manager = init_instance_by_config(C["exp_manager"], module) - qr = QlibRecorder(exp_manager, uri) + qr = QlibRecorder(exp_manager) R.register(qr) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index a941ed7cf..5ac673a30 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -43,26 +43,26 @@ class QlibRecorder: def get_uri(self): return self.exp_manager.get_uri() - def get_recorder(self): - return self.exp_manager.active_recorder + def get_recorder(self, recorder_id=None, recorder_name=None): + return self.exp_manager.active_experiment.get_recorder(recorder_id, recorder_name) def save_objects(self, local_path=None, artifact_path=None, **kwargs): - self.exp_manager.active_recorder.save_objects(local_path, artifact_path, **kwargs) + self.exp_manager.active_experiment.active_recorder.save_objects(local_path, artifact_path, **kwargs) def load_object(self, name): - return self.exp_manager.active_recorder.load_object(name) + return self.exp_manager.active_experiment.active_recorder.load_object(name) def log_params(self, **kwargs): - self.exp_manager.active_recorder.log_params(**kwargs) + self.exp_manager.active_experiment.active_recorder.log_params(**kwargs) def log_metrics(self, step=None, **kwargs): - self.exp_manager.active_recorder.log_metrics(step, **kwargs) + self.exp_manager.active_experiment.active_recorder.log_metrics(step, **kwargs) def set_tags(self, **kwargs): - self.exp_manager.active_recorder.set_tags(**kwargs) + self.exp_manager.active_experiment.active_recorder.set_tags(**kwargs) def delete_tag(self, *key): - self.exp_manager.active_recorder.delete_tag(*key) + self.exp_manager.active_experiment.active_recorder.delete_tag(*key) # global record diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 86163c0ea..e4ef6d8a6 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -4,7 +4,9 @@ import mlflow from pathlib import Path from .recorder import MLflowRecorder +from ..log import get_module_logger +logger = get_module_logger("workflow", "INFO") class Experiment: """ diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index e81da0fcb..2afdee279 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -184,10 +184,12 @@ class MLflowExpManager(ExpManager): else: if experiment_name not in self.experiments: if mlflow.get_experiment_by_name(experiment_name) is not None: - raise Exception( - "The experiment has already been created before. Please pick another name or delete the files under uri." + logger.info( + "The experiment has already been created before. Try to resume the experiment..." ) - experiment_id = mlflow.create_experiment(experiment_name) + experiment_id = mlflow.get_experiment_by_name(experiment_name).experiment_id + else: + experiment_id = mlflow.create_experiment(experiment_name) else: experiment_id = self.experiments[experiment_name].id experiment = self.experiments[experiment_name] diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index cf3a86f7f..d92f836a8 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -76,7 +76,7 @@ class SignalRecord(RecordTemp): def generate(self, **kwargs): # generate prediciton pred = self.model.predict(self.dataset) - self.recorder.save_object(pred, "pred.pkl") + self.recorder.save_objects(data=pred, name="pred.pkl") def load(self): # try to load the saved object @@ -132,9 +132,9 @@ class PortAnaRecord(SignalRecord): assert super().check(), "Make sure the parent process is completed and store the data properly." # custom strategy and get backtest pred_score = super().load() - report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.BACKTEST_CONFIG) - self.recorder.save_object(report_normal, "report_normal.pkl", self.artifact_path) - self.recorder.save_object(positions_normal, "positions_normal.pkl", self.artifact_path) + report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) + self.recorder.save_objects(data=report_normal, name="report_normal.pkl", artifact_path=self.artifact_path) + self.recorder.save_objects(data=positions_normal, name="positions_normal.pkl", artifact_path=self.artifact_path) # analysis analysis = dict() @@ -143,7 +143,7 @@ class PortAnaRecord(SignalRecord): report_normal["return"] - report_normal["bench"] - report_normal["cost"] ) analysis_df = pd.concat(analysis) # type: pd.DataFrame - self.recorder.save_object(pred, "port_analysis.pkl", self.artifact_path) + self.recorder.save_objects(data=analysis_df, name="port_analysis.pkl", artifact_path=self.artifact_path) def load(self): # try to load the saved object diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 157e29347..89b16e9f1 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -201,7 +201,7 @@ class MLflowRecorder(Recorder): def load_object(self, name): client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) - path = client.download_artifacts(self.recorder_id, name) + path = client.download_artifacts(self.id, name) try: with Path(path).open("rb") as f: f.seek(0) @@ -242,5 +242,5 @@ class MLflowRecorder(Recorder): def list_artifacts(self, artifact_path=None): client = mlflow.tracking.MlflowClient(tracking_uri=self._uri) - artifacts = client.list_artifacts(self.id, path) + artifacts = client.list_artifacts(self.id, artifact_path) return artifacts