mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
online serving V7
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from pprint import pprint
|
||||
import time
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.model.trainer import TrainerR, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
@@ -102,7 +103,7 @@ def task_training(tasks, task_pool, exp_name):
|
||||
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(task_pool, exp_name):
|
||||
def task_collecting(exp_name):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
def rec_key(recorder):
|
||||
@@ -141,7 +142,7 @@ def main(
|
||||
reset(task_pool, experiment_name)
|
||||
tasks = task_generating()
|
||||
task_training(tasks, task_pool, experiment_name)
|
||||
task_collecting(task_pool, experiment_name)
|
||||
task_collecting(experiment_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
198
examples/online_srv/online_management_simulate.py
Normal file
198
examples/online_srv/online_management_simulate.py
Normal file
@@ -0,0 +1,198 @@
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.online.simulator import OnlineSimulator
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
"""
|
||||
This examples is about the OnlineManager and OnlineSimulator based on Rolling tasks.
|
||||
The OnlineManager will focus on the updating of your online models.
|
||||
The OnlineSimulator will focus on the simulating real updating routine of your online models.
|
||||
"""
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": None, # "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class OnlineManagerExample:
|
||||
def __init__(
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
):
|
||||
"""
|
||||
init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
region (str, optional): the stock region. Defaults to "cn".
|
||||
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
|
||||
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
|
||||
task_db_name (str, optional): database name. Defaults to "rolling_db".
|
||||
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
|
||||
rolling_step (int, optional): the step for rolling. Defaults to 80.
|
||||
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
|
||||
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) # The rolling tasks generator
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # The trainer based on (R)ecorder and Task(M)anager
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.collector = RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key) # The result collector
|
||||
self.grouper = RollingGroup() # Divide your results into different rolling group
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=self.rolling_gen,
|
||||
trainer=self.trainer,
|
||||
collector=self.collector,
|
||||
need_log=False,
|
||||
) # The OnlineManager based on Rolling
|
||||
self.onlinesimulator = OnlineSimulator(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
onlinemanager=self.rolling_online_manager,
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@staticmethod
|
||||
def rec_key(recorder):
|
||||
"""
|
||||
given a Recorder and return its key to identify it
|
||||
|
||||
Args:
|
||||
recorder (Recorder): a instance of the Recorder
|
||||
|
||||
Returns:
|
||||
tuple: (model_key, rolling_key)
|
||||
"""
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
def result_collecting(self):
|
||||
print("========== result collecting ==========")
|
||||
|
||||
# ens_workflow can help collect, group and ensemble results in a easy way
|
||||
artifact = ens_workflow(self.rolling_online_manager.get_collector(), self.grouper)
|
||||
print(artifact)
|
||||
|
||||
# Run this firstly to see the workflow in OnlineManager
|
||||
def first_train(self):
|
||||
print("========== first train ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
generators=[self.rolling_gen], # generate different date segment
|
||||
)
|
||||
|
||||
self.rolling_online_manager.prepare_new_models(tasks=tasks, tag=RollingOnlineManager.ONLINE_TAG)
|
||||
self.result_collecting()
|
||||
|
||||
# Run this secondly to see the simulating in OnlineSimulator
|
||||
def simulate(self):
|
||||
|
||||
print("========== simulate ==========")
|
||||
self.onlinesimulator.simulate()
|
||||
|
||||
self.result_collecting()
|
||||
|
||||
print("========== online models ==========")
|
||||
recs_dict = self.onlinesimulator.online_models()
|
||||
for time, recs in recs_dict.items():
|
||||
print(f"{str(time[0])} to {str(time[1])}:")
|
||||
for rec in recs:
|
||||
print(rec.info["id"])
|
||||
|
||||
# Run this to run all workflow automaticly
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.simulate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automaticly with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineManagerExample)
|
||||
@@ -1,163 +0,0 @@
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import TaskGen
|
||||
from qlib.workflow.online.simulator import OnlineSimulator
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": "csi100",
|
||||
}
|
||||
|
||||
dataset_config = {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
record_config = [
|
||||
{
|
||||
"class": "SignalRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
{
|
||||
"class": "SigAnaRecord",
|
||||
"module_path": "qlib.workflow.record_temp",
|
||||
},
|
||||
]
|
||||
|
||||
# use lgb model
|
||||
task_lgb_config = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
# use xgboost model
|
||||
task_xgboost_config = {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": dataset_config,
|
||||
"record": record_config,
|
||||
}
|
||||
|
||||
|
||||
class OnlineSimulatorExample:
|
||||
def __init__(
|
||||
self,
|
||||
exp_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=80,
|
||||
):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
|
||||
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool)
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@staticmethod
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
# Run this firstly to see the workflow in Task Management
|
||||
def first_run(self):
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
|
||||
tasks = task_generator(
|
||||
tasks=task_xgboost_config,
|
||||
generators=[self.rolling_gen], # generate different date segment
|
||||
)
|
||||
|
||||
pprint(tasks)
|
||||
|
||||
self.trainer.train(tasks)
|
||||
|
||||
print("========== task collecting ==========")
|
||||
|
||||
artifact = ens_workflow(RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key), RollingGroup())
|
||||
print(artifact)
|
||||
|
||||
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
self.rolling_online_manager.set_online_tag(RollingOnlineManager.ONLINE_TAG, list(latest_rec.values()))
|
||||
|
||||
def simulate(self):
|
||||
|
||||
print("========== simulate ==========")
|
||||
onlinesimulator = OnlineSimulator(
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
onlinemanager=self.rolling_online_manager,
|
||||
collector=RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key),
|
||||
process_list=[RollingGroup()],
|
||||
)
|
||||
results = onlinesimulator.simulate()
|
||||
print(results)
|
||||
recs_dict = onlinesimulator.online_models()
|
||||
for time, recs in recs_dict.items():
|
||||
print(f"{str(time[0])} to {str(time[1])}:")
|
||||
for rec in recs:
|
||||
print(rec.info["id"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ose = OnlineSimulatorExample()
|
||||
ose.first_run()
|
||||
ose.simulate()
|
||||
@@ -123,7 +123,8 @@ class RollingOnlineExample:
|
||||
return tasks
|
||||
|
||||
def task_training(self, tasks):
|
||||
self.trainer.train(tasks)
|
||||
# self.trainer.train(tasks)
|
||||
self.rolling_online_manager.prepare_new_models(tasks, tag=RollingOnlineManager.ONLINE_TAG)
|
||||
|
||||
# This part corresponds to "Task Collecting" in the document
|
||||
def task_collecting(self):
|
||||
@@ -165,10 +166,8 @@ class RollingOnlineExample:
|
||||
self.task_training(tasks)
|
||||
self.task_collecting()
|
||||
|
||||
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
|
||||
self.routine()
|
||||
# latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
|
||||
# self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
|
||||
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
@@ -177,6 +176,10 @@ class RollingOnlineExample:
|
||||
self.print_online_model()
|
||||
self.task_collecting()
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
|
||||
@@ -488,7 +488,7 @@ class TSDatasetH(DatasetH):
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
"""
|
||||
dtype = kwargs.pop("dtype")
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
data = self._prepare_raw_seg(slc=slc, **kwargs)
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)
|
||||
|
||||
@@ -26,7 +26,6 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
----------
|
||||
Recorder : The instance of the recorder
|
||||
"""
|
||||
|
||||
# model initiaiton
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
@@ -46,7 +45,7 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
recorder = R.get_recorder()
|
||||
recorder: Recorder = R.get_recorder()
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
|
||||
@@ -4,6 +4,7 @@ from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import MLflowRecorder, Recorder
|
||||
from qlib.workflow.online.update import PredUpdater, RecordUpdater
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
@@ -14,78 +15,127 @@ from qlib.model.trainer import Trainer, TrainerR
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
class OnlineManager:
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
# NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models.
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, trainer: Trainer = None, need_log=True):
|
||||
self._trainer = trainer
|
||||
def __init__(self, trainer: Trainer = None, collector: Collector = None, need_log=True):
|
||||
"""
|
||||
init OnlineManager.
|
||||
|
||||
Args:
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.delay_signals = {}
|
||||
self.collector = collector
|
||||
self.cur_time = None
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""return the new tasks waiting for training."""
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
return the new tasks waiting for training.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, tasks):
|
||||
"""Use trainer to train a list of tasks and set the trained model to next_online.
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `tag`.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
tag (str):
|
||||
`ONLINE_TAG` for first train or additional train
|
||||
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
|
||||
`OFFLINE_TAG` for train but offline those models
|
||||
"""
|
||||
if not (tasks is None or len(tasks) == 0):
|
||||
if self._trainer is not None:
|
||||
new_models = self._trainer.train(tasks)
|
||||
self.set_online_tag(self.NEXT_ONLINE_TAG, new_models)
|
||||
self.logger.info(
|
||||
f"Finished prepare {len(new_models)} new models and set them to `{self.NEXT_ONLINE_TAG}`."
|
||||
)
|
||||
if self.trainer is not None:
|
||||
new_models = self.trainer.train(tasks)
|
||||
self.set_online_tag(tag, new_models)
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished prepare {len(new_models)} new models and set them to {tag}.")
|
||||
else:
|
||||
self.logger.warn("No trainer to train new tasks.")
|
||||
|
||||
def update_online_pred(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, update the predictions of online models to latest.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
def set_online_tag(self, tag, *args, **kwargs):
|
||||
"""set `tag` to the model to sign whether online
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in ONLINE_TAG, NEXT_ONLINE_TAG, OFFLINE_TAG
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, *args, **kwargs):
|
||||
"""given a model and return its online tag"""
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, *args, **kwargs):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing."""
|
||||
"""
|
||||
Offline all models and set the models to 'online'.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def online_models(self):
|
||||
"""return online models"""
|
||||
"""
|
||||
Return online models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def get_collector(self):
|
||||
"""
|
||||
Return the collector.
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
return self.collector
|
||||
|
||||
def run_delay_signals(self):
|
||||
"""
|
||||
Prepare all signals if there are some dates waiting for prepare.
|
||||
"""
|
||||
for cur_time, params in self.delay_signals.items():
|
||||
self.cur_time = cur_time
|
||||
self.prepare_signals(*params[0], **params[1])
|
||||
self.delay_signals = {}
|
||||
|
||||
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
|
||||
"""The typical update process in a routine such as day by day or month by month"""
|
||||
"""
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
Prepare signals -> prepare tasks -> prepare new models -> update online prediction -> reset online models
|
||||
"""
|
||||
self.cur_time = cur_time # None for latest date
|
||||
if not delay_prepare:
|
||||
self.prepare_signals(*args, **kwargs)
|
||||
else:
|
||||
self.delay_signals[cur_time] = (args, kwargs)
|
||||
if cur_time is not None:
|
||||
self.delay_signals[cur_time] = (args, kwargs)
|
||||
else:
|
||||
raise ValueError("Can not delay prepare when cur_time is None")
|
||||
tasks = self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(tasks)
|
||||
self.update_online_pred()
|
||||
@@ -98,9 +148,18 @@ class OnlineManagerR(OnlineManager):
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None, collector: Collector = None, need_log=True):
|
||||
"""
|
||||
init OnlineManagerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer, need_log)
|
||||
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
@@ -148,7 +207,9 @@ class OnlineManagerR(OnlineManager):
|
||||
)
|
||||
|
||||
def update_online_pred(self):
|
||||
"""update all online model predictions to the latest day in Calendar"""
|
||||
"""
|
||||
Update all online model predictions to the latest day in Calendar
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
|
||||
@@ -160,18 +221,39 @@ class OnlineManagerR(OnlineManager):
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
"""An implementation of OnlineManager based on Rolling."""
|
||||
|
||||
def __init__(self, experiment_name: str, rolling_gen: RollingGen, trainer: Trainer = None, need_log=True):
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
collector: Collector = None,
|
||||
need_log=True,
|
||||
):
|
||||
"""
|
||||
init RollingOnlineManager.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
rolling_gen (RollingGen): a instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(experiment_name, trainer, need_log=need_log)
|
||||
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
"""
|
||||
Must use `pass` even though there is nothing to do.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""prepare new tasks based on new date.
|
||||
"""
|
||||
Prepare new tasks based on new date.
|
||||
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
@@ -184,7 +266,11 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = self.ta.last_date() if self.cur_time is None else self.cur_time
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time and last rolling test begin time is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rid, rec in latest_records.items():
|
||||
|
||||
@@ -1,12 +1,6 @@
|
||||
from typing import Callable
|
||||
import pandas as pd
|
||||
from qlib.config import C
|
||||
from qlib.data import D
|
||||
from qlib import get_module_logger
|
||||
from qlib.log import set_log_with_config
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.task.collect import Collector
|
||||
|
||||
|
||||
class OnlineSimulator:
|
||||
@@ -20,21 +14,24 @@ class OnlineSimulator:
|
||||
end_time,
|
||||
onlinemanager: OnlineManager,
|
||||
frequency="day",
|
||||
time_delta="20 hours",
|
||||
collector: Collector = None,
|
||||
process_list: list = None,
|
||||
):
|
||||
"""
|
||||
init OnlineSimulator.
|
||||
|
||||
Args:
|
||||
start_time (str or pd.Timestamp): the start time of simulating.
|
||||
end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest.
|
||||
onlinemanager (OnlineManager): the instance of OnlineManager
|
||||
frequency (str, optional): the data frequency. Defaults to "day".
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
|
||||
self.start_time = self.cal[0]
|
||||
self.end_time = self.cal[-1]
|
||||
self.olm = onlinemanager
|
||||
self.time_delta = time_delta
|
||||
|
||||
if len(self.cal) == 0:
|
||||
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
|
||||
self.collector = collector
|
||||
self.process_list = process_list
|
||||
|
||||
def simulate(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -42,14 +39,13 @@ class OnlineSimulator:
|
||||
NOTE: Considering the parallel training, the signals will be perpared after all routine simulating.
|
||||
|
||||
Returns:
|
||||
dict: the simulated results collected by collector
|
||||
Collector: the OnlineManager's collector
|
||||
"""
|
||||
self.rec_dict = {}
|
||||
tmp_begin = self.start_time
|
||||
tmp_end = None
|
||||
prev_recorders = self.olm.online_models()
|
||||
for cur_time in self.cal:
|
||||
cur_time = cur_time + pd.Timedelta(self.time_delta)
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
recorders = self.olm.routine(cur_time, True, *args, **kwargs)
|
||||
if len(recorders) == 0:
|
||||
@@ -64,8 +60,7 @@ class OnlineSimulator:
|
||||
self.olm.run_delay_signals()
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
|
||||
if self.collector is not None:
|
||||
return ens_workflow(self.collector, self.process_list)
|
||||
return self.olm.get_collector()
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
|
||||
@@ -121,6 +121,7 @@ class PredUpdater(RecordUpdater):
|
||||
# FIXME: the problme below is not solved
|
||||
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
|
||||
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
@@ -136,7 +137,7 @@ class PredUpdater(RecordUpdater):
|
||||
# Load model
|
||||
model = self.rmdl.get_model()
|
||||
|
||||
new_pred = model.predict(dataset)
|
||||
new_pred: pd.Series = model.predict(dataset)
|
||||
|
||||
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
|
||||
@@ -168,7 +168,7 @@ class RollingGen(TaskGen):
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
|
||||
@@ -12,6 +12,7 @@ import pickle
|
||||
from pymongo.errors import InvalidDocument
|
||||
from bson.objectid import ObjectId
|
||||
from contextlib import contextmanager
|
||||
import qlib
|
||||
from tqdm.cli import tqdm
|
||||
import time
|
||||
import concurrent
|
||||
@@ -65,6 +66,12 @@ class TaskManager:
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self):
|
||||
"""
|
||||
list the all collection(task_pool) of the db
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
@@ -257,9 +264,6 @@ class TaskManager:
|
||||
query: dict
|
||||
the dict of query
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
|
||||
@@ -15,10 +15,21 @@ def get_mongodb():
|
||||
|
||||
get database in MongoDB, which means you need to declare the address and the name of database.
|
||||
for example:
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/",
|
||||
"task_db_name" : "rolling_db"
|
||||
}
|
||||
|
||||
Using qlib.init():
|
||||
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(..., mongo=mongo_conf)
|
||||
|
||||
After qlib.init():
|
||||
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/",
|
||||
"task_db_name" : "rolling_db"
|
||||
}
|
||||
|
||||
"""
|
||||
try:
|
||||
@@ -113,6 +124,16 @@ class TimeAdjuster:
|
||||
return idx
|
||||
|
||||
def cal_interval(self, time_point_A, time_point_B):
|
||||
"""
|
||||
calculate the trading day interval
|
||||
|
||||
Args:
|
||||
time_point_A : time_point_A
|
||||
time_point_B : time_point_B (is the past of time_point_A)
|
||||
|
||||
Returns:
|
||||
int: the interval between A and B
|
||||
"""
|
||||
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
|
||||
Reference in New Issue
Block a user