mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
bug fixed
This commit is contained in:
@@ -3,20 +3,19 @@
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are two parts including first train and routine.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
|
||||
data_handler_config = {
|
||||
@@ -84,7 +83,8 @@ class RollingOnlineExample:
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
tasks=[task_xgboost_config],
|
||||
add_tasks=[task_lgb_config],
|
||||
):
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
@@ -92,11 +92,12 @@ class RollingOnlineExample:
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.tasks = tasks
|
||||
self.add_tasks = add_tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategy = []
|
||||
strategies = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategy.append(
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
@@ -104,8 +105,7 @@ class RollingOnlineExample:
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategy)
|
||||
self.collector = self.rolling_online_manager.get_collector()
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
@@ -113,40 +113,60 @@ class RollingOnlineExample:
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks:
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.reset()
|
||||
print("========== first_run ==========")
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
print("========== collect results ==========")
|
||||
print(self.collector())
|
||||
|
||||
def routine(self):
|
||||
print("========== load ==========")
|
||||
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
|
||||
self.rolling_online_manager = pickle.load(f)
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print("========== collect results ==========")
|
||||
print(self.collector())
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def add_strategy(self):
|
||||
print("========== load ==========")
|
||||
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
|
||||
print("========== add strategy ==========")
|
||||
strategies = []
|
||||
for task in self.add_tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategies.append(
|
||||
RollingStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
self.rolling_online_manager.add_strategy(strategies=strategies)
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
self.add_strategy()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -262,12 +262,29 @@ class OnlineManager(Serializable):
|
||||
Prepare all models and signals if something is waiting for preparation.
|
||||
|
||||
Args:
|
||||
model_kwargs: the params for `prepare_online_models`
|
||||
model_kwargs: the params for `end_train`
|
||||
signal_kwargs: the params for `prepare_signals`
|
||||
"""
|
||||
last_models = {}
|
||||
signals_time = D.calendar()[0]
|
||||
need_prepare = False
|
||||
for cur_time, strategy_models in self.history.items():
|
||||
self.cur_time = cur_time
|
||||
|
||||
for strategy, models in strategy_models.items():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
# only new online models need to prepare
|
||||
if last_models.setdefault(strategy, set()) != set(models):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
|
||||
strategy.tool.reset_online_tag(models)
|
||||
need_prepare = True
|
||||
last_models[strategy] = set(models)
|
||||
|
||||
if need_prepare:
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
self.logger.warn(
|
||||
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
need_prepare = False
|
||||
signals_time = self.signals.index.get_level_values("datetime").max()
|
||||
|
||||
@@ -39,6 +39,9 @@ class Recorder:
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.info["id"])
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
output = dict()
|
||||
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.info["id"])
|
||||
|
||||
def __eq__(self, o: object) -> bool:
|
||||
if isinstance(o, MLflowRecorder):
|
||||
return self.info["id"] == o.info["id"]
|
||||
return False
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
Reference in New Issue
Block a user