From c8d7d3ea2aed896b352cc07dd7d747424af0efa0 Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 22 Nov 2020 03:17:50 +0000 Subject: [PATCH] fix record_tmp bug --- .github/workflows/test.yml | 2 +- qlib/workflow/record_temp.py | 38 +++++++++++++++++++++++++----------- setup.py | 2 +- tests/test_all_pipeline.py | 2 +- 4 files changed, 30 insertions(+), 14 deletions(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index d1e01e46b..935d03116 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -56,4 +56,4 @@ jobs: - name: Test workflow by config run: | - workflow_by_config examples/benchmarks/GBDT/workflow_config_gbdt.yaml \ No newline at end of file + qrun examples/benchmarks/GBDT/workflow_config_gbdt.yaml diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index af4e99acb..b1fd9cc83 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -25,6 +25,19 @@ class RecordTemp: backtest in a certain format. """ + artifact_path = None + + @classmethod + def get_path(cls, path=None): + names = [] + if cls.artifact_path is not None: + names.append(cls.artifact_path) + + if path is not None: + names.append(path) + + return "/".join(names) + def __init__(self, recorder): self.recorder = recorder @@ -79,7 +92,7 @@ class RecordTemp: artifacts = set(self.recorder.list_artifacts()) if parent: # Downcasting have to be done here instead of using `super` - flist = self.__class__.__base__.list(self) + flist = self.__class__.__base__.list(self) # pylint: disable=E1101 else: flist = self.list() for item in flist: @@ -132,10 +145,12 @@ class SignalRecord(RecordTemp): class SigAnaRecord(SignalRecord): + + artifact_path = "sig_analysis" + def __init__(self, recorder, **kwargs): super().__init__(recorder=recorder, **kwargs) # The name must be unique. Otherwise it will be overridden - self.artifact_path_sig = "sig_analysis" def generate(self): self.check(parent=True) @@ -150,11 +165,11 @@ class SigAnaRecord(SignalRecord): "Rank ICIR": ric.mean() / ric.std(), } self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.artifact_path_sig) + self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.get_path()) pprint(metrics) def list(self): - return ["{self.artifact_path_sig}/ic.pkl", "{self.artifact_path_sig}/ric.pkl"] + return [self.get_path("ic.pkl"), self.get_path("ric.pkl")] class PortAnaRecord(SignalRecord): @@ -162,6 +177,8 @@ class PortAnaRecord(SignalRecord): This is the Portfolio Analysis Record class that generates the results such as those of backtest. """ + artifact_path = "portfolio_analysis" + def __init__(self, recorder, config, **kwargs): """ config["strategy"] : dict @@ -174,7 +191,6 @@ class PortAnaRecord(SignalRecord): self.strategy_config = config["strategy"] self.backtest_config = config["backtest"] self.strategy = init_instance_by_config(self.strategy_config) - self.artifact_path_port = "portfolio_analysis" def generate(self, **kwargs): # check previously stored prediction results @@ -183,8 +199,8 @@ class PortAnaRecord(SignalRecord): # custom strategy and get backtest pred_score = super().load() report_normal, positions_normal = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config) - self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=self.artifact_path_port) - self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=self.artifact_path_port) + self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()) + self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()) # analysis analysis = dict() @@ -197,7 +213,7 @@ class PortAnaRecord(SignalRecord): # log metrics self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict())) # save results - self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=self.artifact_path_port) + self.recorder.save_objects(**{"port_analysis.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()) logger.info( f"Portfolio analysis record 'port_analysis.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -209,7 +225,7 @@ class PortAnaRecord(SignalRecord): def list(self): return [ - f"{self.artifact_path_port}/report_normal.pkl", - f"{self.artifact_path_port}/positions_normal.pkl", - f"{self.artifact_path_port}/port_analysis.pkl", + PortAnaRecord.get_path("report_normal.pkl"), + PortAnaRecord.get_path("positions_normal.pkl"), + PortAnaRecord.get_path("port_analysis.pkl"), ] diff --git a/setup.py b/setup.py index 7c2688666..7d7ea7fdb 100644 --- a/setup.py +++ b/setup.py @@ -101,7 +101,7 @@ setup( entry_points={ # 'console_scripts': ['mycli=mymodule:cli'], "console_scripts": [ - "workflow_by_config=qlib.workflow.cli:run", + "qrun=qlib.workflow.cli:run", ], }, ext_modules=extensions, diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 16242189a..d2fb506ee 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -147,7 +147,7 @@ def backtest_analysis(pred, rid): # backtest par = PortAnaRecord(recorder, port_analysis_config) par.generate() - analysis_df = par.load("port_analysis.pkl") + analysis_df = par.load(par.get_path("port_analysis.pkl")) print(analysis_df) return analysis_df