1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Refactor update & modification when running NN

This commit is contained in:
Young
2021-04-11 14:39:19 +00:00
parent a366c11d67
commit cca43cf102
8 changed files with 211 additions and 33 deletions

View File

@@ -58,7 +58,7 @@ class RollingEnsemble(Ensemble):
"""Merge the rolling objects in an Ensemble"""
def __call__(self, ensemble_dict: dict, *args, **kwargs):
def __call__(self, ensemble_dict: dict):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of the dict must be pd.DataFrame and have the index "datetime"

View File

@@ -1,6 +1,7 @@
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from qlib.utils.serial import Serializable
from joblib import Parallel, delayed
class Group(Serializable):
@@ -18,10 +19,23 @@ class Group(Serializable):
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
"""
self.group = group_func
self.ens = ens
self._group_func = group_func
self._ens_func = ens
def __call__(self, ungrouped_dict: dict, *args, **kwargs):
def group(self, *args, **kwargs):
    """Group the inputs by delegating to the configured ``_group_func``.

    Args:
        *args: positional arguments forwarded unchanged to ``_group_func``.
        **kwargs: keyword arguments forwarded unchanged to ``_group_func``.

    Returns:
        Whatever ``_group_func`` returns.

    Raises:
        NotImplementedError: if ``_group_func`` is missing or is not callable.
    """
    # TODO: such design is weird when `_group_func` is the only configurable part in the class
    group_func = getattr(self, "_group_func", None)
    # `getattr` with a default covers instances on which the attribute was never set.
    if not callable(group_func):
        raise NotImplementedError("Please specify valid `group_func`.")
    return group_func(*args, **kwargs)
def reduce(self, *args, **kwargs):
    """Reduce one grouped value by delegating to the configured ``_ens_func``.

    Args:
        *args: positional arguments forwarded unchanged to ``_ens_func``.
        **kwargs: keyword arguments forwarded unchanged to ``_ens_func``.

    Returns:
        Whatever ``_ens_func`` returns.

    Raises:
        NotImplementedError: if ``_ens_func`` is missing or is not callable.
    """
    # NOTE(review): a FIXME elsewhere in this class reports this very error being
    # raised under multiprocessing; presumably `_ens_func` does not survive the
    # pickling of `self` -- confirm how the base class serializes
    # underscore-prefixed attributes.
    ens_func = getattr(self, "_ens_func", None)
    if not callable(ens_func):
        raise NotImplementedError("Please specify valid `_ens_func`.")
    return ens_func(*args, **kwargs)
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs):
"""Group the ungrouped_dict into different groups.
Args:
@@ -30,23 +44,24 @@ class Group(Serializable):
Returns:
dict: grouped_dict like {G1: object, G2: object}
"""
if isinstance(getattr(self, "group", None), Callable):
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
if self.ens is not None:
ens_dict = {}
for key, value in grouped_dict.items():
ens_dict[key] = self.ens(value)
grouped_dict = ens_dict
return grouped_dict
else:
raise NotImplementedError(f"Please specify valid group_func.")
# FIXME: Running this with multiprocessing raises the following error:
# NotImplementedError: Please specify valid `_ens_func`.
# The cause may be that the state of the function is lost
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
key_l = []
job_l = []
for key, value in grouped_dict.items():
key_l.append(key)
job_l.append(delayed(Group.reduce)(self, value))
return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
class RollingGroup(Group):
"""group the rolling dict"""
@staticmethod
def rolling_group(rolling_dict: dict):
def group(self, rolling_dict: dict):
"""Given a rolling dict like {(A,B,R): things}, return the grouped dict like {(A,B): {R:things}}
NOTE: This assumes that the rolling key is at the end of the key tuple, because the rolling results always need to be ensembled first.
@@ -63,7 +78,5 @@ class RollingGroup(Group):
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
return grouped_dict
def __init__(self, group_func=None):
super().__init__(group_func=group_func, ens=RollingEnsemble())
if group_func is None:
self.group = RollingGroup.rolling_group
def __init__(self):
super().__init__(ens=RollingEnsemble())

View File

@@ -8,6 +8,7 @@ from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.data.dataset import Dataset
from qlib.model.base import Model
import socket
def task_train(task_config: dict, experiment_name: str) -> Recorder:
@@ -35,16 +36,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
# train model
R.log_params(**flatten_dict(task_config))
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(hostname=socket.gethostname())
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# This dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
recorder = R.get_recorder()
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:

View File

@@ -522,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
"""get the trading date obtained by shifting `trading_date` by `shift` trading days
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -535,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
"""
from qlib.data import D
cal = D.calendar(future=future)
cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)

View File

@@ -1,13 +1,142 @@
from typing import Union, List
from qlib.data.dataset import DatasetH
from qlib.workflow import R
from qlib.data import D
import pandas as pd
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model import Model
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset import DatasetH
from abc import ABCMeta, abstractmethod
from qlib.utils import get_date_by_shift
class RMDLoader:
    """
    Recorder Model Dataset Loader

    Thin helper that loads the "dataset" and "params.pkl" objects stored in a recorder.
    """

    def __init__(self, rec: Recorder):
        # The recorder that every `load_object` call below goes through.
        self.rec = rec

    def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
        """
        Load the dataset saved in the recorder, then configure and set it up.
        This dataset is for inference.

        Parameters
        ----------
        start_time :
            the start_time of underlying data
        end_time :
            the end_time of underlying data
        segments : dict
            the segments config for dataset.
            When omitted, a single "test" segment spanning (start_time, end_time) is used.
            For a time series dataset (TSDatasetH), the test segment may differ from start_time and end_time.

        Returns
        -------
        DatasetH
            the loaded dataset after `config` and `setup_data` have been called on it.
        """
        seg_conf = {"test": (start_time, end_time)} if segments is None else segments
        handler_window = {"start_time": start_time, "end_time": end_time}

        ds: DatasetH = self.rec.load_object("dataset")
        ds.config(handler_kwargs=handler_window, segments=seg_conf)
        # NOTE(review): IT_LS is presumably the "load state" init type of the handler -- confirm.
        ds.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
        return ds

    def get_model(self) -> Model:
        """Return the object stored in the recorder under "params.pkl"."""
        model = self.rec.load_object("params.pkl")
        return model
class RecordUpdater(metaclass=ABCMeta):
    """
    Abstract base for objects that update one specific recorder.

    Concrete subclasses must implement `update`.
    """

    def __init__(self, record: Recorder, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used here.
        self.record = record

    @abstractmethod
    def update(self, *args, **kwargs):
        """
        Update info for the specific recorder held in `self.record`.
        """
class PredUpdater(RecordUpdater):
    """
    Update the prediction ("pred.pkl") stored in a Recorder
    """

    # Sentinel for `to_date`: use the last entry of the calendar.
    LATEST = "__latest"

    def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day"):
        """
        Parameters
        ----------
        record : Recorder
            the recorder whose prediction will be updated
        to_date :
            update the prediction up to `to_date`;
            the default `LATEST` resolves to the last entry of `D.calendar(freq=freq)`
        hist_ref : int
            Sometimes the dataset has historical dependencies.
            It is left to the user to set the length of the historical dependency.
            NOTE: the start_time is not included in the hist_ref
            # TODO: automate this step in the future.
        freq : str
            the frequency passed to the calendar and to `get_date_by_shift`
        """
        super().__init__(record=record)
        self.to_date = to_date  # raw value; replaced by a Timestamp below
        self.hist_ref = hist_ref
        self.freq = freq
        self.rmdl = RMDLoader(rec=record)

        resolved = D.calendar(freq=freq)[-1] if to_date == self.LATEST else to_date
        self.to_date = pd.Timestamp(resolved)

        # Existing predictions and the latest datetime they already cover.
        self.old_pred = record.load_object("pred.pkl")
        datetimes = self.old_pred.index.get_level_values("datetime")
        self.last_end = datetimes.max()

    def prepare_data(self) -> DatasetH:
        """
        Load the dataset covering the period that still needs predictions.

        Keeping this as a separate method makes it easier to reuse the dataset.
        """
        # The test segment starts one step after the last existing prediction.
        test_start = get_date_by_shift(self.last_end, 1, freq=self.freq)
        # The underlying data starts earlier so that `hist_ref` steps of history are available.
        data_start = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
        return self.rmdl.get_dataset(
            start_time=data_start,
            end_time=self.to_date,
            segments={"test": (test_start, self.to_date)},
        )

    def update(self, dataset: DatasetH = None):
        """
        Update the prediction in the recorder.

        Parameters
        ----------
        dataset : DatasetH
            a prepared dataset to reuse; when None, `prepare_data` is called.
        """
        # FIXME: the problem below is not solved
        # A model dumped on a GPU instance can not be loaded on a CPU instance. The following exception will be raised
        # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
        if dataset is None:
            # Not supplied by the caller, so build it here.
            dataset = self.prepare_data()

        new_pred = self.rmdl.get_model().predict(dataset)

        # Append the new scores to the old predictions and keep the index sorted.
        merged = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0).sort_index()
        self.record.save_objects(**{"pred.pkl": merged})

        logger = get_module_logger(self.__class__.__name__)
        logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
class ModelUpdater:

View File

@@ -25,6 +25,8 @@ class Collector(Serializable):
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
def __init__(
self,
exp_name,
@@ -48,9 +50,9 @@ class RecorderCollector(Collector):
rec_key_func = lambda rec: rec.info["id"]
if artifacts_key is None:
artifacts_key = self.artifacts_path.keys()
self.rec_key = rec_key_func
self._rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter = rec_filter_func
self._rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None):
"""Collect different artifacts based on recorder after filtering.
@@ -65,7 +67,7 @@ class RecorderCollector(Collector):
if artifacts_key is None:
artifacts_key = self.artifacts_key
if rec_filter_func is None:
rec_filter_func = self.rec_filter
rec_filter_func = self._rec_filter_func
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
@@ -74,9 +76,12 @@ class RecorderCollector(Collector):
# filter records
recs_flt = list_recorders(self.exp_name, rec_filter_func)
for _, rec in recs_flt.items():
rec_key = self.rec_key(rec)
rec_key = self._rec_key_func(rec)
for key in artifacts_key:
artifact = rec.load_object(self.artifacts_path[key])
if self.ART_KEY_RAW == key:
artifact = rec
else:
artifact = rec.load_object(self.artifacts_path[key])
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict

View File

@@ -80,6 +80,12 @@ class TaskGen(metaclass=abc.ABCMeta):
"""
pass
def __call__(self, *args, **kwargs):
    """
    Syntactic sugar: calling the generator is the same as calling `generate`.

    All arguments are forwarded unchanged and the result of `generate` is returned.
    """
    generated = self.generate(*args, **kwargs)
    return generated
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date

View File

@@ -18,7 +18,8 @@ import concurrent
import pymongo
from qlib.config import C
from .utils import get_mongodb
from qlib import get_module_logger
from qlib import get_module_logger, auto_init
import fire
class TaskManager:
@@ -49,7 +50,7 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str):
def __init__(self, task_pool: str = None):
"""
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
@@ -59,7 +60,8 @@ class TaskManager:
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
self.task_pool = getattr(self.mdb, task_pool)
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self):
@@ -287,6 +289,20 @@ class TaskManager:
query["_id"] = ObjectId(query["_id"])
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
    """
    Set the priority of a task.

    Parameters
    ----------
    task : dict
        The task query from the database; only its "_id" is used to select the document
    priority : int
        the target priority
    """
    selector = {"_id": task["_id"]}
    self.task_pool.update_one(selector, {"$set": {"priority": priority}})
def _get_undone_n(self, task_stat):
    """Return the waiting count plus the running count from `task_stat` (a missing status counts as 0)."""
    waiting = task_stat.get(self.STATUS_WAITING, 0)
    running = task_stat.get(self.STATUS_RUNNING, 0)
    return waiting + running
@@ -345,3 +361,10 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
ever_run = True
return ever_run
if __name__ == "__main__":
    # Command-line entry point: expose TaskManager through `fire`.
    # E.g. : `python -m qlib.workflow.task.manage list`
    # `auto_init()` is called first, before the CLI is dispatched.
    auto_init()
    fire.Fire(TaskManager)