1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

online serving V7

This commit is contained in:
lzh222333
2021-04-16 05:37:13 +00:00
parent 5095b2a470
commit cec318fbfe
12 changed files with 370 additions and 225 deletions

View File

@@ -1,9 +1,10 @@
from pprint import pprint
import time
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.model.trainer import TrainerR, task_train
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
@@ -102,7 +103,7 @@ def task_training(tasks, task_pool, exp_name):
# This part corresponds to "Task Collecting" in the document
def task_collecting(task_pool, exp_name):
def task_collecting(exp_name):
print("========== task_collecting ==========")
def rec_key(recorder):
@@ -141,7 +142,7 @@ def main(
reset(task_pool, experiment_name)
tasks = task_generating()
task_training(tasks, task_pool, experiment_name)
task_collecting(task_pool, experiment_name)
task_collecting(experiment_name)
if __name__ == "__main__":

View File

@@ -0,0 +1,198 @@
import fire
import qlib
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.online.simulator import OnlineSimulator
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
"""
This examples is about the OnlineManager and OnlineSimulator based on Rolling tasks.
The OnlineManager will focus on the updating of your online models.
The OnlineSimulator will focus on the simulating real updating routine of your online models.
"""
data_handler_config = {
"start_time": "2018-01-01",
"end_time": None, # "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class OnlineManagerExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
):
"""
init OnlineManagerExample.
Args:
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
region (str, optional): the stock region. Defaults to "cn".
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
task_db_name (str, optional): database name. Defaults to "rolling_db".
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
rolling_step (int, optional): the step for rolling. Defaults to 80.
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
"""
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD) # The rolling tasks generator
self.trainer = TrainerRM(self.exp_name, self.task_pool) # The trainer based on (R)ecorder and Task(M)anager
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
self.collector = RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key) # The result collector
self.grouper = RollingGroup() # Divide your results into different rolling group
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=self.rolling_gen,
trainer=self.trainer,
collector=self.collector,
need_log=False,
) # The OnlineManager based on Rolling
self.onlinesimulator = OnlineSimulator(
start_time=start_time,
end_time=end_time,
onlinemanager=self.rolling_online_manager,
)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
@staticmethod
def rec_key(recorder):
"""
given a Recorder and return its key to identify it
Args:
recorder (Recorder): a instance of the Recorder
Returns:
tuple: (model_key, rolling_key)
"""
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def result_collecting(self):
print("========== result collecting ==========")
# ens_workflow can help collect, group and ensemble results in a easy way
artifact = ens_workflow(self.rolling_online_manager.get_collector(), self.grouper)
print(artifact)
# Run this firstly to see the workflow in OnlineManager
def first_train(self):
print("========== first train ==========")
self.reset()
tasks = task_generator(
tasks=[task_xgboost_config, task_lgb_config],
generators=[self.rolling_gen], # generate different date segment
)
self.rolling_online_manager.prepare_new_models(tasks=tasks, tag=RollingOnlineManager.ONLINE_TAG)
self.result_collecting()
# Run this secondly to see the simulating in OnlineSimulator
def simulate(self):
print("========== simulate ==========")
self.onlinesimulator.simulate()
self.result_collecting()
print("========== online models ==========")
recs_dict = self.onlinesimulator.online_models()
for time, recs in recs_dict.items():
print(f"{str(time[0])} to {str(time[1])}:")
for rec in recs:
print(rec.info["id"])
# Run this to run all workflow automaticly
def main(self):
self.first_train()
self.simulate()
if __name__ == "__main__":
## to run all workflow automaticly with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineManagerExample)

View File

@@ -1,163 +0,0 @@
from abc import abstractmethod
import copy
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow import R
from qlib.workflow.task.gen import TaskGen
from qlib.workflow.online.simulator import OnlineSimulator
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
from qlib.model.trainer import TrainerRM
from qlib.model.ens.group import RollingGroup
data_handler_config = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class OnlineSimulatorExample:
def __init__(
self,
exp_name="rolling_exp",
task_pool="rolling_task",
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=80,
):
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD)
self.trainer = TrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool)
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
@staticmethod
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
# Run this firstly to see the workflow in Task Management
def first_run(self):
print("========== first_run ==========")
self.reset()
tasks = task_generator(
tasks=task_xgboost_config,
generators=[self.rolling_gen], # generate different date segment
)
pprint(tasks)
self.trainer.train(tasks)
print("========== task collecting ==========")
artifact = ens_workflow(RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key), RollingGroup())
print(artifact)
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
self.rolling_online_manager.set_online_tag(RollingOnlineManager.ONLINE_TAG, list(latest_rec.values()))
def simulate(self):
print("========== simulate ==========")
onlinesimulator = OnlineSimulator(
start_time="2018-09-10",
end_time="2018-10-31",
onlinemanager=self.rolling_online_manager,
collector=RecorderCollector(exp_name=self.exp_name, rec_key_func=self.rec_key),
process_list=[RollingGroup()],
)
results = onlinesimulator.simulate()
print(results)
recs_dict = onlinesimulator.online_models()
for time, recs in recs_dict.items():
print(f"{str(time[0])} to {str(time[1])}:")
for rec in recs:
print(rec.info["id"])
if __name__ == "__main__":
ose = OnlineSimulatorExample()
ose.first_run()
ose.simulate()

View File

@@ -123,7 +123,8 @@ class RollingOnlineExample:
return tasks
def task_training(self, tasks):
self.trainer.train(tasks)
# self.trainer.train(tasks)
self.rolling_online_manager.prepare_new_models(tasks, tag=RollingOnlineManager.ONLINE_TAG)
# This part corresponds to "Task Collecting" in the document
def task_collecting(self):
@@ -165,10 +166,8 @@ class RollingOnlineExample:
self.task_training(tasks)
self.task_collecting()
latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
self.routine()
# latest_rec, _ = self.rolling_online_manager.list_latest_recorders()
# self.rolling_online_manager.reset_online_tag(list(latest_rec.values()))
def routine(self):
print("========== routine ==========")
@@ -177,6 +176,10 @@ class RollingOnlineExample:
self.print_online_model()
self.task_collecting()
def main(self):
self.first_run()
self.routine()
if __name__ == "__main__":
####### to train the first version's models, use the command below

View File

@@ -488,7 +488,7 @@ class TSDatasetH(DatasetH):
"""
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
"""
dtype = kwargs.pop("dtype")
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
data = self._prepare_raw_seg(slc=slc, **kwargs)
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype)

View File

@@ -26,7 +26,6 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
----------
Recorder : The instance of the recorder
"""
# model initiaiton
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
@@ -46,7 +45,7 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
recorder = R.get_recorder()
recorder: Recorder = R.get_recorder()
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:

View File

@@ -4,6 +4,7 @@ from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import MLflowRecorder, Recorder
from qlib.workflow.online.update import PredUpdater, RecordUpdater
from qlib.workflow.task.collect import Collector
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
@@ -14,78 +15,127 @@ from qlib.model.trainer import Trainer, TrainerR
from copy import deepcopy
class OnlineManager(Serializable):
class OnlineManager:
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
# NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models.
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, trainer: Trainer = None, need_log=True):
self._trainer = trainer
def __init__(self, trainer: Trainer = None, collector: Collector = None, need_log=True):
"""
init OnlineManager.
Args:
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
self.trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.delay_signals = {}
self.collector = collector
self.cur_time = None
def prepare_signals(self, *args, **kwargs):
"""
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
"""
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def prepare_tasks(self, *args, **kwargs):
"""return the new tasks waiting for training."""
"""
After the end of a routine, check whether we need to prepare and train some new tasks.
return the new tasks waiting for training.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, tasks):
"""Use trainer to train a list of tasks and set the trained model to next_online.
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG):
"""
Use trainer to train a list of tasks and set the trained model to `tag`.
Args:
tasks (list): a list of tasks.
tag (str):
`ONLINE_TAG` for first train or additional train
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
`OFFLINE_TAG` for train but offline those models
"""
if not (tasks is None or len(tasks) == 0):
if self._trainer is not None:
new_models = self._trainer.train(tasks)
self.set_online_tag(self.NEXT_ONLINE_TAG, new_models)
self.logger.info(
f"Finished prepare {len(new_models)} new models and set them to `{self.NEXT_ONLINE_TAG}`."
)
if self.trainer is not None:
new_models = self.trainer.train(tasks)
self.set_online_tag(tag, new_models)
if self.need_log:
self.logger.info(f"Finished prepare {len(new_models)} new models and set them to {tag}.")
else:
self.logger.warn("No trainer to train new tasks.")
def update_online_pred(self, *args, **kwargs):
"""
After the end of a routine, update the predictions of online models to latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
def set_online_tag(self, tag, *args, **kwargs):
"""set `tag` to the model to sign whether online
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in ONLINE_TAG, NEXT_ONLINE_TAG, OFFLINE_TAG
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, *args, **kwargs):
"""given a model and return its online tag"""
"""
Given a model and return its online tag.
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, *args, **kwargs):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing."""
"""
Offline all models and set the models to 'online'.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def online_models(self):
"""return online models"""
"""
Return online models.
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def get_collector(self):
"""
Return the collector.
Returns:
Collector
"""
return self.collector
def run_delay_signals(self):
"""
Prepare all signals if there are some dates waiting for prepare.
"""
for cur_time, params in self.delay_signals.items():
self.cur_time = cur_time
self.prepare_signals(*params[0], **params[1])
self.delay_signals = {}
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
"""The typical update process in a routine such as day by day or month by month"""
"""
The typical update process after a routine, such as day by day or month by month.
Prepare signals -> prepare tasks -> prepare new models -> update online prediction -> reset online models
"""
self.cur_time = cur_time # None for latest date
if not delay_prepare:
self.prepare_signals(*args, **kwargs)
else:
self.delay_signals[cur_time] = (args, kwargs)
if cur_time is not None:
self.delay_signals[cur_time] = (args, kwargs)
else:
raise ValueError("Can not delay prepare when cur_time is None")
tasks = self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(tasks)
self.update_online_pred()
@@ -98,9 +148,18 @@ class OnlineManagerR(OnlineManager):
"""
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
def __init__(self, experiment_name: str, trainer: Trainer = None, collector: Collector = None, need_log=True):
"""
init OnlineManagerR.
Args:
experiment_name (str): the experiment name.
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
trainer = TrainerR(experiment_name)
super().__init__(trainer, need_log)
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
@@ -148,7 +207,9 @@ class OnlineManagerR(OnlineManager):
)
def update_online_pred(self):
"""update all online model predictions to the latest day in Calendar"""
"""
Update all online model predictions to the latest day in Calendar
"""
online_models = self.online_models()
for rec in online_models:
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
@@ -160,18 +221,39 @@ class OnlineManagerR(OnlineManager):
class RollingOnlineManager(OnlineManagerR):
"""An implementation of OnlineManager based on Rolling."""
def __init__(self, experiment_name: str, rolling_gen: RollingGen, trainer: Trainer = None, need_log=True):
def __init__(
self,
experiment_name: str,
rolling_gen: RollingGen,
trainer: Trainer = None,
collector: Collector = None,
need_log=True,
):
"""
init RollingOnlineManager.
Args:
experiment_name (str): the experiment name.
rolling_gen (RollingGen): a instance of RollingGen
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
trainer = TrainerR(experiment_name)
super().__init__(experiment_name, trainer, need_log=need_log)
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.logger = get_module_logger(self.__class__.__name__)
def prepare_signals(self, *args, **kwargs):
"""
Must use `pass` even though there is nothing to do.
"""
pass
def prepare_tasks(self, *args, **kwargs):
"""prepare new tasks based on new date.
"""
Prepare new tasks based on new date.
Returns:
list: a list of new tasks.
@@ -184,7 +266,11 @@ class RollingOnlineManager(OnlineManagerR):
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = self.ta.last_date() if self.cur_time is None else self.cur_time
if self.ta.cal_interval(calendar_latest, max_test[0]) > self.rg.step:
if self.need_log:
self.logger.info(
f"The interval between current time and last rolling test begin time is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rid, rec in latest_records.items():

View File

@@ -1,12 +1,6 @@
from typing import Callable
import pandas as pd
from qlib.config import C
from qlib.data import D
from qlib import get_module_logger
from qlib.log import set_log_with_config
from qlib.model.ens.ensemble import ens_workflow
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.task.collect import Collector
class OnlineSimulator:
@@ -20,21 +14,24 @@ class OnlineSimulator:
end_time,
onlinemanager: OnlineManager,
frequency="day",
time_delta="20 hours",
collector: Collector = None,
process_list: list = None,
):
"""
init OnlineSimulator.
Args:
start_time (str or pd.Timestamp): the start time of simulating.
end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest.
onlinemanager (OnlineManager): the instance of OnlineManager
frequency (str, optional): the data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
self.start_time = self.cal[0]
self.end_time = self.cal[-1]
self.olm = onlinemanager
self.time_delta = time_delta
if len(self.cal) == 0:
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
self.collector = collector
self.process_list = process_list
def simulate(self, *args, **kwargs):
"""
@@ -42,14 +39,13 @@ class OnlineSimulator:
NOTE: Considering the parallel training, the signals will be perpared after all routine simulating.
Returns:
dict: the simulated results collected by collector
Collector: the OnlineManager's collector
"""
self.rec_dict = {}
tmp_begin = self.start_time
tmp_end = None
prev_recorders = self.olm.online_models()
for cur_time in self.cal:
cur_time = cur_time + pd.Timedelta(self.time_delta)
self.logger.info(f"Simulating at {str(cur_time)}......")
recorders = self.olm.routine(cur_time, True, *args, **kwargs)
if len(recorders) == 0:
@@ -64,8 +60,7 @@ class OnlineSimulator:
self.olm.run_delay_signals()
self.logger.info(f"Finished preparing signals")
if self.collector is not None:
return ens_workflow(self.collector, self.process_list)
return self.olm.get_collector()
def online_models(self):
"""

View File

@@ -121,6 +121,7 @@ class PredUpdater(RecordUpdater):
# FIXME: the problme below is not solved
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
@@ -136,7 +137,7 @@ class PredUpdater(RecordUpdater):
# Load model
model = self.rmdl.get_model()
new_pred = model.predict(dataset)
new_pred: pd.Series = model.predict(dataset)
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()

View File

@@ -168,7 +168,7 @@ class RollingGen(TaskGen):
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])

View File

@@ -12,6 +12,7 @@ import pickle
from pymongo.errors import InvalidDocument
from bson.objectid import ObjectId
from contextlib import contextmanager
import qlib
from tqdm.cli import tqdm
import time
import concurrent
@@ -65,6 +66,12 @@ class TaskManager:
self.logger = get_module_logger(self.__class__.__name__)
def list(self):
"""
list the all collection(task_pool) of the db
Returns:
list
"""
return self.mdb.list_collection_names()
def _encode_task(self, task):
@@ -257,9 +264,6 @@ class TaskManager:
query: dict
the dict of query
Returns
-------
"""
query = query.copy()
if "_id" in query:

View File

@@ -15,10 +15,21 @@ def get_mongodb():
get database in MongoDB, which means you need to declare the address and the name of database.
for example:
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/",
"task_db_name" : "rolling_db"
}
Using qlib.init():
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(..., mongo=mongo_conf)
After qlib.init():
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/",
"task_db_name" : "rolling_db"
}
"""
try:
@@ -113,6 +124,16 @@ class TimeAdjuster:
return idx
def cal_interval(self, time_point_A, time_point_B):
"""
calculate the trading day interval
Args:
time_point_A : time_point_A
time_point_B : time_point_B (is the past of time_point_A)
Returns:
int: the interval between A and B
"""
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
def align_time(self, time_point, tp_type="start"):