1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00
Files
qlib/qlib/workflow/online/update.py
Linlang Lv (iSoftStone) 5200ff520a fix_download_data_for_CI
2022-03-25 16:56:02 +08:00

290 lines
10 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Updater is a module to update artifacts such as predictions when the stock data is updating.
"""
from abc import ABCMeta, abstractmethod
from typing import Optional
import pandas as pd
from qlib import get_module_logger
from qlib.data import D
from qlib.data.dataset import Dataset, DatasetH, TSDatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
from qlib.workflow.recorder import Recorder
from qlib.workflow.record_temp import SignalRecord
class RMDLoader:
    """
    Recorder Model Dataset Loader

    Small helper bound to one recorder; it fetches the dataset and the model
    objects stored in that recorder.
    """

    def __init__(self, rec: Recorder):
        self.rec = rec

    def get_dataset(
        self, start_time, end_time, segments=None, unprepared_dataset: Optional[DatasetH] = None
    ) -> DatasetH:
        """
        Load, config and setup dataset.

        This dataset is for inference.

        Args:
            start_time :
                the start_time of underlying data
            end_time :
                the end_time of underlying data
            segments : dict
                the segments config for dataset.
                Because of time series datasets (TSDatasetH), the test segment may differ from
                (start_time, end_time). When omitted, ``{"test": (start_time, end_time)}`` is used.
            unprepared_dataset: Optional[DatasetH]
                if the user doesn't want to load the dataset from the recorder, pass the user's dataset here.
                Note that the given dataset is configured and set up in place.

        Returns:
            DatasetH: the instance of DatasetH
        """
        test_segments = {"test": (start_time, end_time)} if segments is None else segments
        # A caller-supplied dataset wins; the recorder is only consulted as a fallback.
        dataset: DatasetH = (
            self.rec.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
        )
        time_range = {"start_time": start_time, "end_time": end_time}
        dataset.config(handler_kwargs=time_range, segments=test_segments)
        dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
        return dataset

    def get_model(self) -> Model:
        """Return the object stored in the recorder under ``params.pkl``."""
        return self.rec.load_object("params.pkl")
class RecordUpdater(metaclass=ABCMeta):
    """
    Abstract base for objects that update one specific recorder.
    """

    def __init__(self, record: Recorder, *args, **kwargs):
        # The logger is named after the concrete class so messages can be told apart per updater type.
        self.logger = get_module_logger(type(self).__name__)
        self.record = record

    @abstractmethod
    def update(self, *args, **kwargs):
        """
        Update info for the specific recorder held by this updater.
        """
class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
    """
    Dataset-Based Updater

    - Providing updating feature for Updating data based on Qlib Dataset

    Assumption

    - Based on Qlib dataset
    - The data to be updated is a multi-level index pd.DataFrame. For example label , prediction.

        .. code-block::

                                     LABEL0
            datetime   instrument
            2021-05-10 SH600000    0.006965
                       SH600004    0.003407
            ...                         ...
            2021-05-28 SZ300498    0.015748
                       SZ300676   -0.001321
    """

    def __init__(
        self,
        record: Recorder,
        to_date=None,
        from_date=None,
        hist_ref: Optional[int] = None,
        freq="day",
        fname="pred.pkl",
        loader_cls: type = RMDLoader,
    ):
        """
        Init DSBasedUpdater.

        Expected behavior in following cases:

        - if `to_date` is greater than the max date in the calendar, the data will be updated to the latest date
        - if there are data before `from_date` or after `to_date`, only the data between `from_date` and `to_date` are affected.

        Args:
            record : Recorder
            to_date :
                update the data up to `to_date`

                if to_date is None:

                    data will be updated to the latest date.
            from_date :
                the update will start from `from_date`

                if from_date is None:

                    the updating will occur on the next tick after the latest data in historical data
            hist_ref : int
                Sometimes, the dataset will have historical depends.
                Leave the problem to users to set the length of historical dependency
                If user doesn't specify this parameter, Updater will try to load dataset to automatically determine the hist_ref

                .. note::

                    the start_time is not included in the hist_ref

            freq : str
                calendar frequency used for every date lookup/shift done by this updater
            fname : str
                name of the recorder artifact that holds the data to be updated
            loader_cls : type
                the class to load the model and dataset
        """
        # TODO: automate this hist_ref in the future.
        super().__init__(record=record)

        self.hist_ref = hist_ref
        self.freq = freq
        self.fname = fname
        self.rmdl = loader_cls(rec=record)

        latest_date = D.calendar(freq=freq)[-1]
        if to_date is None:
            to_date = latest_date
        to_date = pd.Timestamp(to_date)

        # Only warn when the request really goes beyond the calendar; an equal date
        # (which includes the default `to_date=None`) needs neither clipping nor a warning.
        if to_date > latest_date:
            self.logger.warning(
                f"The given `to_date`({to_date}) is later than `latest_date`({latest_date}). So `to_date` is clipped to `latest_date`."
            )
            to_date = latest_date
        self.to_date = to_date

        # FIXME: it will raise error when running routine with delay trainer
        # should we use another prediction updater for delay trainer?
        self.old_data: pd.DataFrame = record.load_object(fname)
        if from_date is None:
            # dropna is for being compatible to some data with future information(e.g. label)
            # The recent label data should be updated together
            self.last_end = self.old_data.dropna().index.get_level_values("datetime").max()
        else:
            # Shift on the same calendar frequency as the rest of the updater.
            self.last_end = get_date_by_shift(from_date, -1, freq=self.freq, align="right")

    def prepare_data(self, unprepared_dataset: Optional[DatasetH] = None) -> DatasetH:
        """
        Load dataset

        - if unprepared_dataset is specified, then prepare the dataset directly
        - Otherwise, the loader (`self.rmdl`) is asked to obtain the dataset itself

        Separating this function will make it easier to reuse the dataset

        Returns:
            DatasetH: the instance of DatasetH
        """
        # automatically getting the historical dependency if not specified
        if self.hist_ref is None:
            dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
            # Special treatment of historical dependencies
            hist_ref = dataset.step_len if isinstance(dataset, TSDatasetH) else 0
        else:
            hist_ref = self.hist_ref

        # The handler is loaded from an earlier point (buffer) so that historical
        # dependencies are available, while the "test" segment starts right after `last_end`.
        start_time_buffer = get_date_by_shift(
            self.last_end, -hist_ref + 1, clip_shift=False, freq=self.freq  # pylint: disable=E1130
        )
        start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
        seg = {"test": (start_time, self.to_date)}
        return self.rmdl.get_dataset(
            start_time=start_time_buffer, end_time=self.to_date, segments=seg, unprepared_dataset=unprepared_dataset
        )

    def update(
        self, dataset: Optional[DatasetH] = None, write: bool = True, ret_new: bool = False
    ) -> Optional[object]:
        """
        Parameters
        ----------
        dataset : Optional[DatasetH]
            the instance of DatasetH. None for preparing it again.
        write : bool
            whether the updated data will be saved back to the recorder
        ret_new : bool
            whether the updated data will be returned

        Returns
        -------
        Optional[object]
            the updated data if `ret_new` is True; otherwise None.
            None is also returned when the existing data are already up to date.
        """
        # FIXME: the problem below is not solved
        # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
        # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
        # https://github.com/pytorch/pytorch/issues/16797

        if self.last_end >= self.to_date:
            self.logger.info(
                f"The data in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
            )
            return None

        # load dataset
        if dataset is None:
            # For reusing the dataset
            dataset = self.prepare_data()

        updated_data = self.get_update_data(dataset)

        if write:
            self.record.save_objects(**{self.fname: updated_data})
        if ret_new:
            return updated_data
        return None

    @abstractmethod
    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
        """
        return the updated data based on the given dataset

        The difference between `get_update_data` and `update`

        - `get_update_data` only includes the data-specific part
        - `update` includes the general routine steps (e.g. prepare dataset, checking)
        """
def _replace_range(data, new_data):
    """
    Replace a datetime range of `data` with `new_data`.

    Every row of `data` whose "datetime" level lies between the earliest and the
    latest datetime found in `new_data` (both inclusive) is dropped, then
    `new_data` is appended. Duplicated index entries keep the last occurrence and
    the result is sorted by index.

    Args:
        data : pd.DataFrame
            the existing data; its index must be sliceable by datetime once sorted
        new_data : pd.DataFrame
            the data to merge in; its index must contain a level named "datetime"

    Returns:
        pd.DataFrame: the combined, de-duplicated and sorted data. `data` itself is not modified.
    """
    data = data.sort_index()
    # With nothing to merge, min()/max() of the dates would be NaT and be used as
    # slice bounds; skip the replacement and only normalise the existing data.
    if len(new_data) > 0:
        dates = new_data.index.get_level_values("datetime")
        data = data.drop(data.loc[dates.min() : dates.max()].index)
        data = pd.concat([data, new_data], axis=0)
    return data[~data.index.duplicated(keep="last")].sort_index()
class PredUpdater(DSBasedUpdater):
    """
    Update the prediction in the Recorder
    """

    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
        """
        Predict on `dataset` with the recorder's model and merge the result into the old predictions.

        The new predictions are stored under the column name "score".
        """
        # The model comes from the loader attached to this updater.
        new_pred: pd.Series = self.rmdl.get_model().predict(dataset)
        score_frame = new_pred.to_frame("score")
        merged = _replace_range(self.old_data, score_frame)
        self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
        return merged
class LabelUpdater(DSBasedUpdater):
    """
    Update the label in the recorder

    Assumption

    - The label is generated from record_temp.SignalRecord.
    """

    def __init__(self, record: Recorder, to_date=None, **kwargs):
        # The artifact name is pinned to "label.pkl"; every other option is forwarded untouched.
        super().__init__(record, to_date=to_date, fname="label.pkl", **kwargs)

    def get_update_data(self, dataset: Dataset) -> pd.DataFrame:
        """Generate labels for `dataset` and merge them into the previously stored labels."""
        fresh_label = SignalRecord.generate_label(dataset)
        previous_label = self.old_data.sort_index()
        return _replace_range(previous_label, fresh_label)