1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

bug fixed

This commit is contained in:
lzh222333
2021-05-14 11:31:50 +00:00
parent aef3f186c1
commit a986379deb
3 changed files with 71 additions and 23 deletions

View File

@@ -3,20 +3,19 @@
"""
This example shows how OnlineManager works with rolling tasks.
There are two parts including first train and routine.
There are four parts including first train, routine 1, add strategy and routine 2.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
Finally, the OnlineManager will finish second routine and update all strategies.
"""
import os
from pathlib import Path
import pickle
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import OnlineManager
data_handler_config = {
@@ -84,7 +83,8 @@ class RollingOnlineExample:
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=[task_xgboost_config, task_lgb_config],
tasks=[task_xgboost_config],
add_tasks=[task_lgb_config],
):
mongo_conf = {
"task_url": task_url, # your MongoDB url
@@ -92,11 +92,12 @@ class RollingOnlineExample:
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.tasks = tasks
self.add_tasks = add_tasks
self.rolling_step = rolling_step
strategy = []
strategies = []
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategy.append(
strategies.append(
RollingStrategy(
name_id,
task,
@@ -104,8 +105,7 @@ class RollingOnlineExample:
)
)
self.rolling_online_manager = OnlineManager(strategy)
self.collector = self.rolling_online_manager.get_collector()
self.rolling_online_manager = OnlineManager(strategies)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
@@ -113,40 +113,60 @@ class RollingOnlineExample:
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks:
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
TaskManager(name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
def first_run(self):
print("========== reset ==========")
self.reset()
print("========== first_run ==========")
self.rolling_online_manager.first_train()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
print("========== collect results ==========")
print(self.collector())
def routine(self):
print("========== load ==========")
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
self.rolling_online_manager = pickle.load(f)
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== routine ==========")
self.rolling_online_manager.routine()
print("========== collect results ==========")
print(self.collector())
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def add_strategy(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== add strategy ==========")
strategies = []
for task in self.add_tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager.add_strategy(strategies=strategies)
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def main(self):
self.first_run()
self.routine()
self.add_strategy()
self.routine()
if __name__ == "__main__":

View File

@@ -262,12 +262,29 @@ class OnlineManager(Serializable):
Prepare all models and signals if something is waiting for preparation.
Args:
model_kwargs: the params for `prepare_online_models`
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
last_models = {}
signals_time = D.calendar()[0]
need_prepare = False
for cur_time, strategy_models in self.history.items():
self.cur_time = cur_time
for strategy, models in strategy_models.items():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
# only new online models need to prepare
if last_models.setdefault(strategy, set()) != set(models):
models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
strategy.tool.reset_online_tag(models)
need_prepare = True
last_models[strategy] = set(models)
if need_prepare:
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)
need_prepare = False
signals_time = self.signals.index.get_level_values("datetime").max()

View File

@@ -39,6 +39,9 @@ class Recorder:
def __str__(self):
return str(self.info)
def __hash__(self) -> int:
return hash(self.info["id"])
@property
def info(self):
output = dict()
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
client=self.client,
)
def __hash__(self) -> int:
return hash(self.info["id"])
def __eq__(self, o: object) -> bool:
if isinstance(o, MLflowRecorder):
return self.info["id"] == o.info["id"]
return False
@property
def uri(self):
return self._uri