1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-30 17:41:18 +08:00

bug fixed

This commit is contained in:
lzh222333
2021-04-22 08:09:15 +00:00
parent cec318fbfe
commit de0a0c083d
4 changed files with 29 additions and 33 deletions

View File

@@ -1,18 +1,13 @@
from typing import Dict, Union, List
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model.trainer import task_train
from qlib.workflow.recorder import MLflowRecorder, Recorder
from qlib.workflow.online.update import PredUpdater, RecordUpdater
from qlib.workflow.task.collect import Collector
from qlib.workflow.task.utils import TimeAdjuster
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.manage import run_task
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
from qlib.model.trainer import Trainer, TrainerR
from copy import deepcopy
from typing import Dict, List, Union
from qlib import get_module_logger
from qlib.data.data import D
from qlib.model.trainer import Trainer, TrainerR, task_train
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
class OnlineManager:
@@ -63,6 +58,7 @@ class OnlineManager:
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
`OFFLINE_TAG` for train but offline those models
"""
# TODO: 回调
if not (tasks is None or len(tasks) == 0):
if self.trainer is not None:
new_models = self.trainer.train(tasks)
@@ -158,7 +154,8 @@ class OnlineManagerR(OnlineManager):
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
trainer = TrainerR(experiment_name)
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
self.exp_name = experiment_name
@@ -239,7 +236,8 @@ class RollingOnlineManager(OnlineManagerR):
collector (Collector, optional): a instance of Collector. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
trainer = TrainerR(experiment_name)
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
self.ta = TimeAdjuster()
self.rg = rolling_gen
@@ -247,9 +245,17 @@ class RollingOnlineManager(OnlineManagerR):
def prepare_signals(self, *args, **kwargs):
"""
Average the online models prediction and save them into a recorder
Must use `pass` even though there is nothing to do.
"""
pass
# 检查recorder是否存在如果不存在就创建一个
# 检查recorder的上一个信号时间如果没有那就从上线模型的共同最早时间开始出信号
# 从recorder的上一个信号时间开始出信号出到self.cur_time
for model in self.online_models():
pass
def prepare_tasks(self, *args, **kwargs):
"""
@@ -258,17 +264,17 @@ class RollingOnlineManager(OnlineManagerR):
Returns:
list: a list of new tasks.
"""
self.ta.set_end_time(self.cur_time)
#TODO: max_test = self.cur_time
latest_records, max_test = self.list_latest_recorders(
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
)
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = self.ta.last_date() if self.cur_time is None else self.cur_time
calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time
if self.need_log:
self.logger.info(
f"The interval between current time and last rolling test begin time is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []

View File

@@ -73,9 +73,7 @@ class PredUpdater(RecordUpdater):
Update the prediction in the Recorder
"""
LATEST = "__latest"
def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day", need_log=True):
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
"""
Parameters
----------
@@ -95,8 +93,7 @@ class PredUpdater(RecordUpdater):
self.freq = freq
self.rmdl = RMDLoader(rec=record)
# FIXME: why we need LATEST? can we use to_date=None instead?
if to_date == self.LATEST or to_date == None:
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
self.old_pred = record.load_object("pred.pkl")

View File

@@ -169,7 +169,7 @@ class RollingGen(TaskGen):
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))

View File

@@ -90,17 +90,10 @@ class TimeAdjuster:
def max(self):
"""
(Deprecated)
Return the max calendar datetime
"""
return max(self.cals)
def last_date(self) -> pd.Timestamp:
"""
Return the last datetime in the calendar
"""
return self.cals[-1]
def align_idx(self, time_point, tp_type="start"):
"""
align the index of time_point in the calendar