1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-01 18:11:18 +08:00

refactor online serving rolling api

This commit is contained in:
Young
2021-07-29 04:40:27 +00:00
committed by you-n-g
parent 05d28469ad
commit 9303415666
6 changed files with 81 additions and 54 deletions

View File

@@ -78,4 +78,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -0,0 +1 @@
xgboost

View File

@@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False):
def transform_end_date(end_date=None, freq="day"):
"""get previous trading date
"""handle the end date with various format
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
Otherwise, returns the end_date
----------
end_date: str
end trading date

View File

@@ -10,6 +10,7 @@ from typing import List, Tuple, Union
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.utils import transform_end_date
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
@@ -118,6 +119,7 @@ class RollingStrategy(OnlineStrategy):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
@@ -174,28 +176,20 @@ class RollingStrategy(OnlineStrategy):
Returns:
List[dict]: a list of new tasks.
"""
# TODO: filter recorders by latest test segments is not a necessary
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
calendar_latest = transform_end_date(cur_time)
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
res = []
for rec in latest_records:
task = rec.load_object("task")
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
return res
def _list_latest(self, rec_list: List[Recorder]):
"""

View File

@@ -105,6 +105,8 @@ class PredUpdater(RecordUpdater):
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
# FIXME: it will raise error when running routine with delay trainer
# should we use another predicition updater for delay trainer?
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()

View File

@@ -5,6 +5,7 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
"""
import abc
import copy
import pandas as pd
from typing import List, Union, Callable
from qlib.utils import transform_end_date
@@ -139,6 +140,53 @@ class RollingGen(TaskGen):
self.test_key = "test"
self.train_key = "train"
def _update_task_segs(self, task, segs):
# update segments of this task
task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segs)
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(task, self)
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
"""
generating following rolling tasks for `task` until test_end
Parameters
----------
task : dict
Qlib task format
test_end : pd.Timestamp
the latest rolling task includes `test_end`
Returns
-------
List[dict]:
the following tasks of `task`(`task` itself is excluded)
"""
t = copy.deepcopy(task)
prev_seg = t["dataset"]["kwargs"]["segments"]
while True:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
prev_seg = segments
self._update_task_segs(t, segments)
yield t
def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
@@ -191,43 +239,23 @@ class RollingGen(TaskGen):
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
# calculate segments
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
# update segments of this task
self._update_task_segs(t, segments)
res.append(t)
# Update the following rolling
res.extend(self.gen_following_tasks(t, test_end))
return res