mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-30 17:41:18 +08:00
bug fixed
This commit is contained in:
@@ -1,18 +1,13 @@
|
||||
from typing import Dict, Union, List
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import MLflowRecorder, Recorder
|
||||
from qlib.workflow.online.update import PredUpdater, RecordUpdater
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import run_task
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.model.trainer import Trainer, TrainerR
|
||||
from copy import deepcopy
|
||||
from typing import Dict, List, Union
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.trainer import Trainer, TrainerR, task_train
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
|
||||
|
||||
class OnlineManager:
|
||||
@@ -63,6 +58,7 @@ class OnlineManager:
|
||||
`NEXT_ONLINE_TAG` for resetting the online model when `reset_online_tag` is called
|
||||
`OFFLINE_TAG` for models that are trained but kept offline
|
||||
"""
|
||||
# TODO: callback
|
||||
if not (tasks is None or len(tasks) == 0):
|
||||
if self.trainer is not None:
|
||||
new_models = self.trainer.train(tasks)
|
||||
@@ -158,7 +154,8 @@ class OnlineManagerR(OnlineManager):
|
||||
collector (Collector, optional): an instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
trainer = TrainerR(experiment_name)
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer=trainer, collector=collector, need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
@@ -239,7 +236,8 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
collector (Collector, optional): an instance of Collector. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
trainer = TrainerR(experiment_name)
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(experiment_name=experiment_name, trainer=trainer, collector=collector, need_log=need_log)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
@@ -247,9 +245,17 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
|
||||
def prepare_signals(self, *args, **kwargs):
|
||||
"""
|
||||
Average the online models' predictions and save them into a recorder
|
||||
|
||||
|
||||
Must use `pass` even though there is nothing to do.
|
||||
"""
|
||||
pass
|
||||
# Check whether the recorder exists; if it does not, create one
|
||||
# Check the recorder's last signal time; if there is none, start generating signals from the earliest time shared by all online models
|
||||
# Generate signals from the recorder's last signal time up to self.cur_time
|
||||
for model in self.online_models():
|
||||
|
||||
pass
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
@@ -258,17 +264,17 @@ class RollingOnlineManager(OnlineManagerR):
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
"""
|
||||
self.ta.set_end_time(self.cur_time)
|
||||
#TODO: max_test = self.cur_time
|
||||
latest_records, max_test = self.list_latest_recorders(
|
||||
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
|
||||
)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = self.ta.last_date() if self.cur_time is None else self.cur_time
|
||||
calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time and last rolling test begin time is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
|
||||
@@ -73,9 +73,7 @@ class PredUpdater(RecordUpdater):
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
LATEST = "__latest"
|
||||
|
||||
def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day", need_log=True):
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -95,8 +93,7 @@ class PredUpdater(RecordUpdater):
|
||||
self.freq = freq
|
||||
self.rmdl = RMDLoader(rec=record)
|
||||
|
||||
# FIXME: why we need LATEST? can we use to_date=None instead?
|
||||
if to_date == self.LATEST or to_date == None:
|
||||
if to_date == None:
|
||||
to_date = D.calendar(freq=freq)[-1]
|
||||
self.to_date = pd.Timestamp(to_date)
|
||||
self.old_pred = record.load_object("pred.pkl")
|
||||
|
||||
@@ -169,7 +169,7 @@ class RollingGen(TaskGen):
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.last_date() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
@@ -90,17 +90,10 @@ class TimeAdjuster:
|
||||
|
||||
def max(self):
|
||||
"""
|
||||
(Deprecated)
|
||||
Return the max calendar datetime
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
def last_date(self) -> pd.Timestamp:
|
||||
"""
|
||||
Return the last datetime in the calendar
|
||||
"""
|
||||
return self.cals[-1]
|
||||
|
||||
def align_idx(self, time_point, tp_type="start"):
|
||||
"""
|
||||
align the index of time_point in the calendar
|
||||
|
||||
Reference in New Issue
Block a user