1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00

Merge pull request #466 from you-n-g/online_hotfix

Online bug fix, enhancement &  docs for dataset, workflow, trainer ...
This commit is contained in:
you-n-g
2021-06-17 11:38:44 +08:00
committed by GitHub
15 changed files with 167 additions and 56 deletions

View File

@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
from .config import C
from .data.cache import H
H.clear()
# FIXME: this logger ignored the level in config
logger = get_module_logger("Initialization", level=logging.INFO)
skip_if_reg = kwargs.pop("skip_if_reg", False)
if skip_if_reg and C.registered:
# If Qlib is reinitialized while an experiment is running (i.e. inside `R.start`),
# the current recorder will be lost.
logger.warning("Skip initialization because `skip_if_reg is True`")
return
H.clear()
C.set(default_conf, **kwargs)
# check path if server/local
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
- Skip initialization if already initialized
"""
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)

View File

@@ -1,6 +1,6 @@
from ...utils.serial import Serializable
from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
@@ -243,6 +243,8 @@ class TSDataSampler:
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
data.
If users have further requirements for processing data, they can process it based on `TSDataSampler` or create
more powerful subclasses.
@@ -309,11 +311,19 @@ class TSDataSampler:
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
if isinstance(flt_data, pd.DataFrame):
assert len(flt_data.columns) == 1
flt_data = flt_data.iloc[:, 0]
# NOTE: bool(np.nan) is True !!!!!!!!
# make sure reindex comes first. Otherwise extra NaN may appear.
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
)
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@@ -341,7 +351,7 @@ class TSDataSampler:
setattr(self, k, v)
@staticmethod
def build_index(data: pd.DataFrame) -> dict:
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
"""
The relation of the data
@@ -352,9 +362,15 @@ class TSDataSampler:
Returns
-------
dict:
{<index>: <prev_index or None>}
# get the previous index of a line given index
Tuple[pd.DataFrame, dict]:
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
datetime
2021-01-11 0 1 2 3 4 5 ...
2021-01-12 4146 4147 4148 4149 4150 4151 ...
2021-01-13 8293 8294 8295 8296 8297 8298 ...
2021-01-14 12441 12442 12443 12444 12445 12446 ...
2) the second element: {<original index>: <row, col>}
"""
# use dtype=object in case pandas converts int to float
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)

View File

@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
def __init__(self, module_name):
self.module_name = module_name
self.level = 0
# this attribute name conflicts with the `level` attribute of Logger
# rename it to avoid some corner cases that result in comparing `str` and `int`
self.__level = 0
@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
logger.setLevel(self.__level)
return logger
def setLevel(self, level):
self.level = level
self.__level = level
def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.

View File

@@ -97,7 +97,7 @@ class ModelFT(Model):
# Finetune model based on previous trained model
with R.start(experiment_name="finetune model"):
recorder = R.get_recorder(rid, experiment_name="init models")
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
model = recorder.load_object("init_model")
model.finetune(dataset, num_boost_round=10)

View File

@@ -8,7 +8,7 @@ There are two steps in each Trainer including ``train``(make model recorder) and
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
"""
import socket
@@ -153,6 +153,9 @@ class Trainer:
"""
return self.delay
def __call__(self, *args, **kwargs) -> list:
return self.end_train(self.train(*args, **kwargs))
class TrainerR(Trainer):
"""
@@ -286,7 +289,9 @@ class TrainerRM(Trainer):
# This tag is the _id in TaskManager to distinguish tasks.
TM_ID = "_id in TaskManager"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
def __init__(
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
):
"""
Init TrainerRM.
@@ -294,11 +299,16 @@ class TrainerRM(Trainer):
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
skip_run_task (bool):
If skip_run_task == True:
`run_task` is skipped in this trainer and the tasks are only run by the workers. Otherwise, `run_task` is called in this trainer.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
self.skip_run_task = skip_run_task
def train(
self,
@@ -340,15 +350,16 @@ class TrainerRM(Trainer):
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.skip_run_task:
run_task(
train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
if not self.is_delay():
tm.wait(query=query)
@@ -411,6 +422,7 @@ class DelayTrainerRM(TrainerRM):
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
skip_run_task: bool = False,
):
"""
Init DelayTrainerRM.
@@ -420,10 +432,15 @@ class DelayTrainerRM(TrainerRM):
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
skip_run_task (bool):
If skip_run_task == True:
`run_task` is skipped in this trainer and the tasks are only run by the workers. Otherwise, `run_task` is called in this trainer.
E.g. Starting the trainer on a CPU VM and then waiting for the tasks to be finished on GPU VMs.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
self.skip_run_task = skip_run_task
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
@@ -477,14 +494,15 @@ class DelayTrainerRM(TrainerRM):
_id_list.append(rec.list_tags()[self.TM_ID])
query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
if not self.skip_run_task:
run_task(
end_train_func,
task_pool,
query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
TaskManager(task_pool=task_pool).wait(query=query)

View File

@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
return pred_left, pred_right
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
    """
    Convert a user-supplied time into a form that is friendly to time slicing.

    Time slicing is a frequently-used action in Qlib and Pandas, but users often
    represent time in many different formats. This function converts those inputs
    into one uniform format.

    Parameters
    ----------
    t : Union[None, str, pd.Timestamp]
        original time

    Returns
    -------
    Union[None, pd.Timestamp]:
        `None` when `t` is `None`; otherwise `pd.Timestamp(t)`.
    """
    if t is None:
        # None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
        return None
    return pd.Timestamp(t)
def can_use_cache():
res = True
r = get_redis_connection()

View File

@@ -216,9 +216,9 @@ class QlibRecorder:
-------
A dictionary (id -> recorder) of recorder information that being stored.
"""
return self.get_exp(experiment_id, experiment_name).list_recorders()
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
"""
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
@@ -263,7 +263,7 @@ class QlibRecorder:
# Case 2
with R.start('test'):
exp = R.get_exp('test1')
exp = R.get_exp(experiment_name='test1')
# Case 3
exp = R.get_exp() -> a default experiment.
@@ -288,7 +288,9 @@ class QlibRecorder:
-------
An experiment instance with given id or name.
"""
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
return self.exp_manager.get_exp(
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
)
def delete_exp(self, experiment_id=None, experiment_name=None):
"""
@@ -332,7 +334,9 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
def get_recorder(
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
) -> Recorder:
"""
Method for retrieving a recorder.
@@ -385,7 +389,7 @@ class QlibRecorder:
-------
A recorder instance.
"""
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
recorder_id, recorder_name, create=False, start=False
)

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from typing import Union
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
@@ -213,11 +214,15 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `_get_recorder` method")
def list_recorders(self):
def list_recorders(self, **flt_kwargs):
"""
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
If users want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
flt_kwargs : dict
filter recorders by conditions
e.g. list_recorders(status=Recorder.STATUS_FI)
Returns
-------
A dictionary (id -> recorder) of recorder information that being stored.
@@ -320,11 +325,21 @@ class MLflowExperiment(Experiment):
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
def list_recorders(self, max_results=UNLIMITED):
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
"""
Parameters
----------
max_results : int
the number limitation of the results
status : str
the criteria based on status to filter results.
`None` indicates no filtering.
"""
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
recorders = dict()
for i in range(len(runs)):
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
recorders[runs[i].info.run_id] = recorder
if status is None or recorder.status == status:
recorders[runs[i].info.run_id] = recorder
return recorders

View File

@@ -109,7 +109,7 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
"""
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
@@ -190,7 +190,7 @@ class ExpManager:
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
return self.create_exp(experiment_name), True
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
if experiment_id is not None:
try:
# NOTE: the mlflow's experiment_id must be str type...
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
exp = self.client.get_experiment(experiment_id)
if exp.lifecycle_stage.upper() == "DELETED":
raise MlflowException("No valid experiment has been found.")

View File

@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
So this module provides a series of methods to control this process.
So this module provides a series of methods to control this process.
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
@@ -31,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
@@ -113,6 +113,8 @@ class OnlineManager(Serializable):
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Training multiple online models at `first_train` will result in getting too many online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
@@ -148,8 +150,6 @@ class OnlineManager(Serializable):
models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
@@ -158,6 +158,11 @@ class OnlineManager(Serializable):
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
# The online models may change during the steps above,
# so updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
@@ -236,7 +241,7 @@ class OnlineManager(Serializable):
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.

View File

@@ -52,6 +52,12 @@ class OnlineStrategy:
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
**NOTE**:
The current implementation is very naive. Here is a more complex situation which is closer to
practical scenarios.
1. Train new models on the day before `test_start` (at timestamp `T`)
2. Switch models at `test_start` (typically at timestamp `T + 1`)
Args:
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManager. None for the latest.

View File

@@ -136,7 +136,7 @@ class PredUpdater(RecordUpdater):
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
if start_time > self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
)

View File

@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
"""
from typing import Callable, Dict, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
from qlib.workflow import R
@@ -192,6 +193,7 @@ class RecorderCollector(Collector):
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
logger = get_module_logger("RecorderCollector")
for _, rec in recs_flt.items():
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
@@ -205,7 +207,13 @@ class RecorderCollector(Collector):
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
# give user some warning if the values are overridden
cdd = collect_dict.setdefault(key, {})
if rec_key in cdd:
logger.warning(
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
)
cdd[rec_key] = artifact
return collect_dict

View File

@@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
import abc
import copy
from typing import List, Union, Callable
from qlib.utils import transform_end_date
from .utils import TimeAdjuster
@@ -199,7 +201,7 @@ class RollingGen(TaskGen):
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
test_end = transform_end_date(segments[self.test_key][1])
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))

View File

@@ -272,10 +272,10 @@ class TaskManager:
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
if task is not None:
self.logger.info("Returning task before raising error")
self.return_task(task)
self.return_task(task, status=status) # return task as the original status
self.logger.info("Task returned")
raise
@@ -411,7 +411,11 @@ class TaskManager:
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
return (
task_stat.get(self.STATUS_WAITING, 0)
+ task_stat.get(self.STATUS_RUNNING, 0)
+ task_stat.get(self.STATUS_PART_DONE, 0)
)
def _get_total(self, task_stat):
return sum(task_stat.values())
@@ -429,7 +433,7 @@ class TaskManager:
last_undone_n = self._get_undone_n(task_stat)
if last_undone_n == 0:
return
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)