mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
logger & doc
This commit is contained in:
@@ -114,12 +114,10 @@ class OnlineSimulationExample:
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen, need_log=False),
|
||||
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
begin_time=self.start_time,
|
||||
need_log=False,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
|
||||
@@ -200,7 +200,7 @@ class DatasetH(Dataset):
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicate fetching data for **inference**.
|
||||
|
||||
kwargs :
|
||||
kwargs :
|
||||
The parameters that kwargs may contain:
|
||||
flt_col : str
|
||||
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
|
||||
@@ -250,7 +250,9 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None):
|
||||
def __init__(
|
||||
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
|
||||
):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
|
||||
@@ -518,17 +520,17 @@ class TSDatasetH(DatasetH):
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop('flt_col', None)
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = self._prepare_raw_seg(slc, **kwargs)
|
||||
|
||||
flt_kwargs = deepcopy(kwargs)
|
||||
if flt_col is not None:
|
||||
flt_kwargs['col_set'] = flt_col
|
||||
flt_kwargs["col_set"] = flt_col
|
||||
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
|
||||
assert len(flt_data.columns) == 1
|
||||
else:
|
||||
flt_data = None
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
|
||||
return tsds
|
||||
return tsds
|
||||
|
||||
@@ -5,7 +5,7 @@
|
||||
The Trainer will train a list of tasks and return a list of model recorder.
|
||||
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
|
||||
|
||||
This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training.
|
||||
This is concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
@@ -103,7 +103,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
The trainer which can train a list of model
|
||||
The trainer can train a list of model.
|
||||
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -113,6 +114,9 @@ class Trainer:
|
||||
"""
|
||||
Given a list of model definition, begin a training and return the models.
|
||||
|
||||
For Trainer, it finish real training in this method.
|
||||
For DelayTrainer, it only do some preparation in this method.
|
||||
|
||||
Args:
|
||||
tasks: a list of tasks
|
||||
|
||||
@@ -126,6 +130,9 @@ class Trainer:
|
||||
Given a list of models, finished something in the end of training if you need.
|
||||
The models maybe Recorder, txt file, database and so on.
|
||||
|
||||
For Trainer, it do some finishing touches in this method.
|
||||
For DelayTrainer, it finish real training in this method.
|
||||
|
||||
Args:
|
||||
models: a list of models
|
||||
|
||||
@@ -326,7 +333,7 @@ class TrainerRM(Trainer):
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> list:
|
||||
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -358,7 +365,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs):
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
|
||||
Args:
|
||||
@@ -378,7 +385,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs):
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
@@ -12,11 +12,13 @@ This module also provide a method to simulate `Online Strategy <#Online Strategy
|
||||
Which means you can verify your strategy or find a better one.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.log import set_global_logger_level
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
from qlib.model.trainer import DelayTrainerR, Trainer
|
||||
from qlib.utils import flatten_dict
|
||||
@@ -37,7 +39,6 @@ class OnlineManager(Serializable):
|
||||
trainer: Trainer = None,
|
||||
begin_time: Union[str, pd.Timestamp] = None,
|
||||
freq="day",
|
||||
need_log=True,
|
||||
):
|
||||
"""
|
||||
Init OnlineManager.
|
||||
@@ -48,10 +49,8 @@ class OnlineManager(Serializable):
|
||||
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
|
||||
trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
|
||||
freq (str, optional): data frequency. Defaults to "day".
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.strategy = strategy
|
||||
@@ -60,19 +59,18 @@ class OnlineManager(Serializable):
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
# The history of online models, which is a dict like {begin_time, {strategy, [online_models]}}
|
||||
# begin_time means when online_models are onlined
|
||||
self.history = {}
|
||||
# OnlineManager will recorder the history of online models, which is a dict like {begin_time, {strategy, [online_models]}}. begin_time means when online_models are onlined.
|
||||
self.history = {}
|
||||
if trainer is None:
|
||||
trainer = DelayTrainerR()
|
||||
self.trainer = trainer
|
||||
self.signals = None
|
||||
|
||||
def first_train(self, strategies:List[OnlineStrategy]=None, model_kwargs: dict = {}):
|
||||
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
|
||||
"""
|
||||
Get tasks from every strategy's first_tasks method and train them.
|
||||
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
|
||||
|
||||
|
||||
Args:
|
||||
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
@@ -119,8 +117,7 @@ class OnlineManager(Serializable):
|
||||
for strategy in self.strategy:
|
||||
if not delay:
|
||||
strategy.tool.update_online_pred()
|
||||
if self.need_log:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
@@ -144,19 +141,20 @@ class OnlineManager(Serializable):
|
||||
delay (bool, optional): if delay prepare models. Defaults to False.
|
||||
model_kwargs (dict, optional): the params for `prepare_online_models`.
|
||||
"""
|
||||
if not models:
|
||||
return
|
||||
if not delay:
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
else:
|
||||
# just set every models as online models temporarily before ``prepare_online_models``
|
||||
online_models = models
|
||||
if len(online_models) > 0:
|
||||
strategy.tool.reset_online_tag(online_models)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
@@ -205,20 +203,23 @@ class OnlineManager(Serializable):
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
|
||||
self.signals = signals
|
||||
return new_signals
|
||||
|
||||
def get_signals(self) -> pd.Series:
|
||||
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Get prepared online signals.
|
||||
|
||||
Returns:
|
||||
pd.Series: signals
|
||||
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
|
||||
pd.DataFrame for multiple signals, for example, buy and sell operation use different trading signal.
|
||||
"""
|
||||
return self.signals
|
||||
|
||||
SIM_LOG_LEVEL = logging.INFO + 1
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}):
|
||||
"""
|
||||
Starting from current time, this method will simulate every routine in OnlineManager until end time.
|
||||
@@ -239,8 +240,13 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
|
||||
self.first_train()
|
||||
|
||||
simulate_level = self.SIM_LOG_LEVEL
|
||||
set_global_logger_level(simulate_level)
|
||||
logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
|
||||
|
||||
for cur_time in cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
|
||||
self.routine(
|
||||
cur_time,
|
||||
delay=self.trainer.is_delay(),
|
||||
@@ -251,6 +257,8 @@ class OnlineManager(Serializable):
|
||||
# delay prepare the models and signals
|
||||
if self.trainer.is_delay():
|
||||
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
|
||||
|
||||
set_global_logger_level(logging.INFO)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.get_collector()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ class OnlineStrategy:
|
||||
OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared.
|
||||
"""
|
||||
|
||||
def __init__(self, name_id: str, need_log=True):
|
||||
def __init__(self, name_id: str):
|
||||
"""
|
||||
Init OnlineStrategy.
|
||||
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
|
||||
@@ -30,12 +30,10 @@ class OnlineStrategy:
|
||||
Args:
|
||||
name_id (str): a unique name or id
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.name_id = name_id
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.tool = OnlineTool(need_log)
|
||||
self.tool = OnlineTool()
|
||||
|
||||
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
|
||||
"""
|
||||
@@ -46,20 +44,21 @@ class OnlineStrategy:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, models, cur_time=None, check_func=None, **kwargs):
|
||||
def prepare_online_models(self, models, cur_time=None, check_func=None) -> List[object]:
|
||||
"""
|
||||
A typically implementation, but maybe you will need old models by online_tool.
|
||||
Use trainer to train a list of tasks and set the trained model to `online`.
|
||||
Select some models as the online models from the trained models.
|
||||
|
||||
NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them.
|
||||
NOTE: This method offline all models and online the online models prepared by this method (if have). So you can find last online models by OnlineTool.online_models if you still need them.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
Returns:
|
||||
List[object]: a list of selected models.
|
||||
"""
|
||||
if check_func is not None:
|
||||
online_models = []
|
||||
@@ -101,7 +100,6 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
name_id: str,
|
||||
task_template: Union[dict, List[dict]],
|
||||
rolling_gen: RollingGen,
|
||||
need_log=True,
|
||||
):
|
||||
"""
|
||||
Init RollingAverageStrategy.
|
||||
@@ -112,15 +110,14 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
name_id (str): a unique name or id. Will be also the name of Experiment.
|
||||
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
super().__init__(name_id=name_id, need_log=need_log)
|
||||
super().__init__(name_id=name_id)
|
||||
self.exp_name = self.name_id
|
||||
if not isinstance(task_template, list):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.rg = rolling_gen
|
||||
self.tool = OnlineToolR(self.exp_name, need_log)
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.ta = TimeAdjuster()
|
||||
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
|
||||
@@ -180,10 +177,9 @@ class RollingAverageStrategy(OnlineStrategy):
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
|
||||
@@ -60,10 +60,9 @@ class RecordUpdater(metaclass=ABCMeta):
|
||||
Update a specific recorders
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, need_log=True, *args, **kwargs):
|
||||
def __init__(self, record: Recorder, *args, **kwargs):
|
||||
self.record = record
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
|
||||
@abstractmethod
|
||||
def update(self, *args, **kwargs):
|
||||
@@ -78,7 +77,7 @@ class PredUpdater(RecordUpdater):
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
|
||||
"""
|
||||
Init PredUpdater.
|
||||
|
||||
@@ -96,7 +95,7 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
"""
|
||||
# TODO: automate this hist_ref in the future.
|
||||
super().__init__(record=record, need_log=need_log)
|
||||
super().__init__(record=record)
|
||||
|
||||
self.to_date = to_date
|
||||
self.hist_ref = hist_ref
|
||||
@@ -138,8 +137,7 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
if self.need_log:
|
||||
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
|
||||
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
|
||||
return
|
||||
|
||||
# load dataset
|
||||
@@ -157,5 +155,4 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
self.record.save_objects(**{"pred.pkl": cb_pred})
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
|
||||
@@ -24,15 +24,11 @@ class OnlineTool:
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, need_log=True):
|
||||
def __init__(self):
|
||||
"""
|
||||
Init OnlineTool.
|
||||
|
||||
Args:
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[list, object]):
|
||||
"""
|
||||
@@ -92,15 +88,14 @@ class OnlineToolR(OnlineTool):
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name:str, need_log=True):
|
||||
def __init__(self, experiment_name: str):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
super().__init__(need_log=need_log)
|
||||
super().__init__()
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
@@ -115,8 +110,7 @@ class OnlineToolR(OnlineTool):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
if self.need_log:
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder) -> str:
|
||||
"""
|
||||
@@ -164,7 +158,6 @@ class OnlineToolR(OnlineTool):
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=to_date, need_log=self.need_log).update()
|
||||
PredUpdater(rec, to_date=to_date).update()
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
|
||||
Reference in New Issue
Block a user