mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Refactor update & modification when running NN
This commit is contained in:
@@ -58,7 +58,7 @@ class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge the rolling objects in an Ensemble"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict, *args, **kwargs):
|
||||
def __call__(self, ensemble_dict: dict):
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.Dataframe, and have the index "datetime"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
from qlib.utils.serial import Serializable
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
class Group(Serializable):
|
||||
@@ -18,10 +19,23 @@ class Group(Serializable):
|
||||
|
||||
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
|
||||
"""
|
||||
self.group = group_func
|
||||
self.ens = ens
|
||||
self._group_func = group_func
|
||||
self._ens_func = ens
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, *args, **kwargs):
|
||||
def group(self, *args, **kwargs):
|
||||
# TODO: such design is weird when `_group_func` is the only configurable part in the class
|
||||
if isinstance(getattr(self, "_group_func", None), Callable):
|
||||
return self._group_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `group_func`.")
|
||||
|
||||
def reduce(self, *args, **kwargs):
|
||||
if isinstance(getattr(self, "_ens_func", None), Callable):
|
||||
return self._ens_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `_ens_func`.")
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs):
|
||||
"""Group the ungrouped_dict into different groups.
|
||||
|
||||
Args:
|
||||
@@ -30,23 +44,24 @@ class Group(Serializable):
|
||||
Returns:
|
||||
dict: grouped_dict like {G1: object, G2: object}
|
||||
"""
|
||||
if isinstance(getattr(self, "group", None), Callable):
|
||||
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
|
||||
if self.ens is not None:
|
||||
ens_dict = {}
|
||||
for key, value in grouped_dict.items():
|
||||
ens_dict[key] = self.ens(value)
|
||||
grouped_dict = ens_dict
|
||||
return grouped_dict
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid group_func.")
|
||||
|
||||
# FIXME: The multiprocessing will raise the following error
|
||||
# NotImplementedError: Please specify valid `_ens_func`.
|
||||
# The problem maybe the state of the function is lost
|
||||
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
|
||||
|
||||
key_l = []
|
||||
job_l = []
|
||||
for key, value in grouped_dict.items():
|
||||
key_l.append(key)
|
||||
job_l.append(delayed(Group.reduce)(self, value))
|
||||
return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
|
||||
|
||||
|
||||
class RollingGroup(Group):
|
||||
"""group the rolling dict"""
|
||||
|
||||
@staticmethod
|
||||
def rolling_group(rolling_dict: dict):
|
||||
def group(self, rolling_dict: dict):
|
||||
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
|
||||
|
||||
NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly.
|
||||
@@ -63,7 +78,5 @@ class RollingGroup(Group):
|
||||
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
|
||||
return grouped_dict
|
||||
|
||||
def __init__(self, group_func=None):
|
||||
super().__init__(group_func=group_func, ens=RollingEnsemble())
|
||||
if group_func is None:
|
||||
self.group = RollingGroup.rolling_group
|
||||
def __init__(self):
|
||||
super().__init__(ens=RollingEnsemble())
|
||||
|
||||
@@ -8,6 +8,7 @@ from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import Model
|
||||
import socket
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
@@ -35,16 +36,17 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
|
||||
# train model
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(hostname=socket.gethostname())
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# This dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
records = task_config.get("record", [])
|
||||
recorder = R.get_recorder()
|
||||
if isinstance(records, dict): # prevent only one dict
|
||||
records = [records]
|
||||
for record in records:
|
||||
|
||||
@@ -522,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
|
||||
"""get trading date with shift bias wil cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -535,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
"""
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future)
|
||||
cal = D.calendar(future=future, freq=freq)
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
|
||||
@@ -1,13 +1,142 @@
|
||||
from typing import Union, List
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.workflow import R
|
||||
from qlib.data import D
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model import Model
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset import DatasetH
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from qlib.utils import get_date_by_shift
|
||||
|
||||
|
||||
class RMDLoader:
|
||||
"""
|
||||
Recorder Model Dataset Loader
|
||||
"""
|
||||
|
||||
def __init__(self, rec: Recorder):
|
||||
self.rec = rec
|
||||
|
||||
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
|
||||
"""
|
||||
load, config and setup dataset.
|
||||
|
||||
This dataset is for inferene
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start_time of underlying data
|
||||
end_time :
|
||||
the end_time of underlying data
|
||||
segments : dict
|
||||
the segments config for dataset
|
||||
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
|
||||
"""
|
||||
if segments is None:
|
||||
segments = {"test": (start_time, end_time)}
|
||||
dataset: DatasetH = self.rec.load_object("dataset")
|
||||
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
|
||||
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
|
||||
return dataset
|
||||
|
||||
def get_model(self) -> Model:
|
||||
return self.rec.load_object("params.pkl")
|
||||
|
||||
|
||||
class RecordUpdater(metaclass=ABCMeta):
|
||||
"""
|
||||
Updata a specific recorders
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, *args, **kwargs):
|
||||
self.record = record
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs):
|
||||
"""
|
||||
Update info for specific recorder
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class PredUpdater(RecordUpdater):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
LATEST = "__latest"
|
||||
|
||||
def __init__(self, record: Recorder, to_date=LATEST, hist_ref: int = 0, freq="day"):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
record : Recorder
|
||||
to_date :
|
||||
update to prediction to the `to_date`
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical depends.
|
||||
Leave the problem to user to set the length of historical dependancy
|
||||
NOTE: the start_time is not included in the hist_ref
|
||||
# TODO: automate this step in the future.
|
||||
"""
|
||||
super().__init__(record=record)
|
||||
|
||||
self.to_date = to_date
|
||||
self.hist_ref = hist_ref
|
||||
self.freq = freq
|
||||
self.rmdl = RMDLoader(rec=record)
|
||||
|
||||
if to_date == self.LATEST:
|
||||
to_date = D.calendar(freq=freq)[-1]
|
||||
self.to_date = pd.Timestamp(to_date)
|
||||
self.old_pred = record.load_object("pred.pkl")
|
||||
self.last_end = self.old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
"""
|
||||
# Load dataset
|
||||
|
||||
Seperating this function will make it easier to reuse the dataset
|
||||
"""
|
||||
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
seg = {"test": (start_time, self.to_date)}
|
||||
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
|
||||
return dataset
|
||||
|
||||
def update(self, dataset: DatasetH = None):
|
||||
"""
|
||||
update the precition in a recorder
|
||||
"""
|
||||
# FIXME: the problme below is not solved
|
||||
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
|
||||
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
||||
|
||||
# load dataset
|
||||
if dataset is None:
|
||||
# For reusing the dataset
|
||||
dataset = self.prepare_data()
|
||||
|
||||
# Load model
|
||||
model = self.rmdl.get_model()
|
||||
|
||||
new_pred = model.predict(dataset)
|
||||
|
||||
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
|
||||
cb_pred = cb_pred.sort_index()
|
||||
|
||||
self.record.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
get_module_logger(self.__class__.__name__).info(
|
||||
f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}."
|
||||
)
|
||||
|
||||
|
||||
class ModelUpdater:
|
||||
|
||||
@@ -25,6 +25,8 @@ class Collector(Serializable):
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
ART_KEY_RAW = "__raw"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
exp_name,
|
||||
@@ -48,9 +50,9 @@ class RecorderCollector(Collector):
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_path.keys()
|
||||
self.rec_key = rec_key_func
|
||||
self._rec_key_func = rec_key_func
|
||||
self.artifacts_key = artifacts_key
|
||||
self.rec_filter = rec_filter_func
|
||||
self._rec_filter_func = rec_filter_func
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None):
|
||||
"""Collect different artifacts based on recorder after filtering.
|
||||
@@ -65,7 +67,7 @@ class RecorderCollector(Collector):
|
||||
if artifacts_key is None:
|
||||
artifacts_key = self.artifacts_key
|
||||
if rec_filter_func is None:
|
||||
rec_filter_func = self.rec_filter
|
||||
rec_filter_func = self._rec_filter_func
|
||||
|
||||
if isinstance(artifacts_key, str):
|
||||
artifacts_key = [artifacts_key]
|
||||
@@ -74,9 +76,12 @@ class RecorderCollector(Collector):
|
||||
# filter records
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key(rec)
|
||||
rec_key = self._rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
if self.ART_KEY_RAW == key:
|
||||
artifact = rec
|
||||
else:
|
||||
artifact = rec.load_object(self.artifacts_path[key])
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
@@ -80,6 +80,12 @@ class TaskGen(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
"""
|
||||
This is just a syntactic sugar for generate
|
||||
"""
|
||||
return self.generate(*args, **kwargs)
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
|
||||
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
|
||||
|
||||
@@ -18,7 +18,8 @@ import concurrent
|
||||
import pymongo
|
||||
from qlib.config import C
|
||||
from .utils import get_mongodb
|
||||
from qlib import get_module_logger
|
||||
from qlib import get_module_logger, auto_init
|
||||
import fire
|
||||
|
||||
|
||||
class TaskManager:
|
||||
@@ -49,7 +50,7 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str):
|
||||
def __init__(self, task_pool: str = None):
|
||||
"""
|
||||
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
|
||||
@@ -59,7 +60,8 @@ class TaskManager:
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self):
|
||||
@@ -287,6 +289,20 @@ class TaskManager:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
"""
|
||||
set priority for task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task : dict
|
||||
The task query from the database
|
||||
priority : int
|
||||
the target priority
|
||||
"""
|
||||
update_dict = {"$set": {"priority": priority}}
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
|
||||
@@ -345,3 +361,10 @@ def run_task(task_func, task_pool, force_release=False, *args, **kwargs):
|
||||
ever_run = True
|
||||
|
||||
return ever_run
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# This is for using it in cmd
|
||||
# E.g. : `python -m qlib.workflow.task.manage list`
|
||||
auto_init()
|
||||
fire.Fire(TaskManager)
|
||||
|
||||
Reference in New Issue
Block a user