diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index 8096fc76f..2b70d4411 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import Optional import qlib import fire @@ -124,11 +125,17 @@ class MultiLevelTradingWorkflow: sr = SignalRecord(model, dataset, recorder) sr.generate() - def backtest(self): + def _load_model(self, load): + return R.get_recorder(load, experiment_name="train").load_object("params.pkl") + + def backtest(self, load_model: Optional[str] = None): self._init_qlib() model = init_instance_by_config(self.task["model"]) dataset = init_instance_by_config(self.task["dataset"]) - self._train_model(model, dataset) + if load_model is None: + self._train_model(model, dataset) + else: + model = self._load_model(load_model) strategy_config = { "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.model_strategy",