mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-01 18:11:18 +08:00
refactor online serving rolling api
This commit is contained in:
@@ -78,4 +78,4 @@ task:
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
config: *port_analysis_config
|
||||
|
||||
1
examples/model_rolling/requirements.txt
Normal file
1
examples/model_rolling/requirements.txt
Normal file
@@ -0,0 +1 @@
|
||||
xgboost
|
||||
@@ -570,9 +570,11 @@ def get_pre_trading_date(trading_date, future=False):
|
||||
|
||||
|
||||
def transform_end_date(end_date=None, freq="day"):
|
||||
"""get previous trading date
|
||||
"""handle the end date with various format
|
||||
|
||||
If end_date is -1, None, or end_date is greater than the maximum trading day, the last trading date is returned.
|
||||
Otherwise, returns the end_date
|
||||
|
||||
----------
|
||||
end_date: str
|
||||
end trading date
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import List, Tuple, Union
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.utils import transform_end_date
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
@@ -118,6 +119,7 @@ class RollingStrategy(OnlineStrategy):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.rg = rolling_gen
|
||||
assert issubclass(self.rg.__class__, RollingGen), "The rolling strategy relies on the feature if RollingGen"
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.ta = TimeAdjuster()
|
||||
|
||||
@@ -174,28 +176,20 @@ class RollingStrategy(OnlineStrategy):
|
||||
Returns:
|
||||
List[dict]: a list of new tasks.
|
||||
"""
|
||||
# TODO: filtering recorders by the latest test segments is not necessary
|
||||
latest_records, max_test = self._list_latest(self.tool.online_models())
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
calendar_latest = transform_end_date(cur_time)
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
res = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
res.extend(self.rg.gen_following_tasks(task, calendar_latest))
|
||||
return res
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
|
||||
"""
|
||||
|
||||
@@ -105,6 +105,8 @@ class PredUpdater(RecordUpdater):
|
||||
if to_date == None:
|
||||
to_date = D.calendar(freq=freq)[-1]
|
||||
self.to_date = pd.Timestamp(to_date)
|
||||
# FIXME: it will raise error when running routine with delay trainer
|
||||
# should we use another prediction updater for delay trainer?
|
||||
self.old_pred = record.load_object("pred.pkl")
|
||||
self.last_end = self.old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
import pandas as pd
|
||||
from typing import List, Union, Callable
|
||||
|
||||
from qlib.utils import transform_end_date
|
||||
@@ -139,6 +140,53 @@ class RollingGen(TaskGen):
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def _update_task_segs(self, task, segs):
    """
    Install `segs` as the dataset segments of `task` (in place).

    Parameters
    ----------
    task : dict
        Qlib task format; its ``task["dataset"]["kwargs"]["segments"]`` entry is overwritten
    segs : dict
        the new segments; a deep copy is stored so `task` never shares state with `segs`
    """
    fresh_segs = copy.deepcopy(segs)
    task["dataset"]["kwargs"]["segments"] = fresh_segs

    # optional user hook, invoked as hook(task, generator) after the segments are in place
    extra_mod = self.ds_extra_mod_func
    if extra_mod is None:
        return
    extra_mod(task, self)
|
||||
|
||||
def gen_following_tasks(self, task: dict, test_end: pd.Timestamp) -> List[dict]:
    """
    generating following rolling tasks for `task` until test_end

    NOTE: this method is a generator; the tasks are produced lazily, so wrap the
    call in ``list(...)`` (or pass it to ``list.extend``) to materialise them.
    Every yielded task is an independent deep copy of `task`, and `task` itself
    is never modified.

    Parameters
    ----------
    task : dict
        Qlib task format
    test_end : pd.Timestamp
        the latest rolling task includes `test_end`

    Returns
    -------
    List[dict]:
        the following tasks of `task`(`task` itself is excluded)
    """
    # Only read from `task` here; the shifted segments are always new objects.
    prev_seg = task["dataset"]["kwargs"]["segments"]
    while True:
        segments = {}
        try:
            for k, seg in prev_seg.items():
                # decide how to shift
                # expanding only for train data, the segments size of test data and valid data won't change
                if k == self.train_key and self.rtype == self.ROLL_EX:
                    rtype = self.ta.SHIFT_EX
                else:
                    rtype = self.ta.SHIFT_SD
                # shift the segments data
                segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
            if segments[self.test_key][0] > test_end:
                break
        except KeyError:
            # We reach the end of tasks
            # No more rolling
            break

        prev_seg = segments
        # A fresh deep copy per step is required: yielding one shared dict would make
        # every collected task alias the same object and end up with the last segments.
        t = copy.deepcopy(task)
        self._update_task_segs(t, segments)
        yield t
|
||||
|
||||
def generate(self, task: dict) -> List[dict]:
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
@@ -191,43 +239,23 @@ class RollingGen(TaskGen):
|
||||
"""
|
||||
res = []
|
||||
|
||||
prev_seg = None
|
||||
test_end = None
|
||||
while True:
|
||||
t = copy.deepcopy(task)
|
||||
t = copy.deepcopy(task)
|
||||
|
||||
# calculate segments
|
||||
if prev_seg is None:
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
else:
|
||||
segments = {}
|
||||
try:
|
||||
for k, seg in prev_seg.items():
|
||||
# decide how to shift
|
||||
# expanding only for train data, the segments size of test data and valid data won't change
|
||||
if k == self.train_key and self.rtype == self.ROLL_EX:
|
||||
rtype = self.ta.SHIFT_EX
|
||||
else:
|
||||
rtype = self.ta.SHIFT_SD
|
||||
# shift the segments data
|
||||
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
|
||||
if segments[self.test_key][0] > test_end:
|
||||
break
|
||||
except KeyError:
|
||||
# We reach the end of tasks
|
||||
# No more rolling
|
||||
break
|
||||
# calculate segments
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
prev_seg = segments
|
||||
if self.ds_extra_mod_func is not None:
|
||||
self.ds_extra_mod_func(t, self)
|
||||
res.append(t)
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
# update segments of this task
|
||||
self._update_task_segs(t, segments)
|
||||
|
||||
res.append(t)
|
||||
|
||||
# Update the following rolling
|
||||
res.extend(self.gen_following_tasks(t, test_end))
|
||||
return res
|
||||
|
||||
Reference in New Issue
Block a user