1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

logger & doc

This commit is contained in:
lzh222333
2021-05-09 11:58:06 +00:00
parent f5ded06a15
commit 370b6aad74
7 changed files with 69 additions and 68 deletions

View File

@@ -114,12 +114,10 @@ class OnlineSimulationExample:
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
self.rolling_online_manager = OnlineManager(
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen, need_log=False),
RollingAverageStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
begin_time=self.start_time,
need_log=False,
)
self.tasks = tasks

View File

@@ -200,7 +200,7 @@ class DatasetH(Dataset):
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
kwargs :
kwargs :
The parameters that kwargs may contain:
flt_col : str
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
@@ -250,7 +250,9 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None):
def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -518,17 +520,17 @@ class TSDatasetH(DatasetH):
"""
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop('flt_col', None)
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs['col_set'] = flt_col
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds
return tsds

View File

@@ -5,7 +5,7 @@
The Trainer will train a list of tasks and return a list of model recorder.
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training.
This is a concept called ``DelayTrainer``, which can be used in online simulation for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
@@ -103,7 +103,8 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
class Trainer:
"""
The trainer which can train a list of model
The trainer can train a list of models.
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
"""
def __init__(self):
@@ -113,6 +114,9 @@ class Trainer:
"""
Given a list of model definition, begin a training and return the models.
For Trainer, it finishes real training in this method.
For DelayTrainer, it only does some preparation in this method.
Args:
tasks: a list of tasks
@@ -126,6 +130,9 @@ class Trainer:
Given a list of models, finished something in the end of training if you need.
The models maybe Recorder, txt file, database and so on.
For Trainer, it does some finishing touches in this method.
For DelayTrainer, it finishes real training in this method.
Args:
models: a list of models
@@ -326,7 +333,7 @@ class TrainerRM(Trainer):
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> list:
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -358,7 +365,7 @@ class DelayTrainerRM(TrainerRM):
self.end_train_func = end_train_func
self.delay = True
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs):
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
Args:
@@ -378,7 +385,7 @@ class DelayTrainerRM(TrainerRM):
**kwargs,
)
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs):
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.

View File

@@ -12,11 +12,13 @@ This module also provide a method to simulate `Online Strategy <#Online Strategy
Which means you can verify your strategy or find a better one.
"""
import logging
from typing import Callable, Dict, List, Union
import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
from qlib.model.trainer import DelayTrainerR, Trainer
from qlib.utils import flatten_dict
@@ -37,7 +39,6 @@ class OnlineManager(Serializable):
trainer: Trainer = None,
begin_time: Union[str, pd.Timestamp] = None,
freq="day",
need_log=True,
):
"""
Init OnlineManager.
@@ -48,10 +49,8 @@ class OnlineManager(Serializable):
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
freq (str, optional): data frequency. Defaults to "day".
need_log (bool, optional): print log or not. Defaults to True.
"""
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
if not isinstance(strategy, list):
strategy = [strategy]
self.strategy = strategy
@@ -60,19 +59,18 @@ class OnlineManager(Serializable):
begin_time = D.calendar(freq=self.freq).max()
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
# The history of online models, which is a dict like {begin_time, {strategy, [online_models]}}
# begin_time means when online_models are onlined
self.history = {}
# OnlineManager will recorder the history of online models, which is a dict like {begin_time, {strategy, [online_models]}}. begin_time means when online_models are onlined.
self.history = {}
if trainer is None:
trainer = DelayTrainerR()
self.trainer = trainer
self.signals = None
def first_train(self, strategies:List[OnlineStrategy]=None, model_kwargs: dict = {}):
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
Get tasks from every strategy's first_tasks method and train them.
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
Args:
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
model_kwargs (dict): the params for `prepare_online_models`
@@ -119,8 +117,7 @@ class OnlineManager(Serializable):
for strategy in self.strategy:
if not delay:
strategy.tool.update_online_pred()
if self.need_log:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
@@ -144,19 +141,20 @@ class OnlineManager(Serializable):
delay (bool, optional): if delay prepare models. Defaults to False.
model_kwargs (dict, optional): the params for `prepare_online_models`.
"""
if not models:
return
if not delay:
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
online_models = strategy.prepare_online_models(models, **model_kwargs)
else:
# just set every model as an online model temporarily before ``prepare_online_models``
online_models = models
if len(online_models) > 0:
strategy.tool.reset_online_tag(online_models)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
self.history.setdefault(self.cur_time, {})[strategy] = online_models
def get_collector(self) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Returns:
MergeCollector: the collector to merge other collectors.
@@ -205,20 +203,23 @@ class OnlineManager(Serializable):
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
if self.need_log:
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
self.signals = signals
return new_signals
def get_signals(self) -> pd.Series:
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
"""
Get prepared online signals.
Returns:
pd.Series: signals
Union[pd.Series, pd.DataFrame]: pd.Series when there is only one signal for every datetime.
pd.DataFrame for multiple signals, for example, when buy and sell operations use different trading signals.
"""
return self.signals
SIM_LOG_LEVEL = logging.INFO + 1
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}):
"""
Starting from current time, this method will simulate every routine in OnlineManager until end time.
@@ -239,8 +240,13 @@ class OnlineManager(Serializable):
"""
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
self.first_train()
simulate_level = self.SIM_LOG_LEVEL
set_global_logger_level(simulate_level)
logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
for cur_time in cal:
self.logger.info(f"Simulating at {str(cur_time)}......")
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
self.routine(
cur_time,
delay=self.trainer.is_delay(),
@@ -251,6 +257,8 @@ class OnlineManager(Serializable):
# delay prepare the models and signals
if self.trainer.is_delay():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
set_global_logger_level(logging.INFO)
self.logger.info(f"Finished preparing signals")
return self.get_collector()

View File

@@ -22,7 +22,7 @@ class OnlineStrategy:
OnlineStrategy works with `Online Manager <#Online Manager>`_ and is responsible for how the tasks are generated, the models are updated and the signals are prepared.
"""
def __init__(self, name_id: str, need_log=True):
def __init__(self, name_id: str):
"""
Init OnlineStrategy.
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finish model training.
@@ -30,12 +30,10 @@ class OnlineStrategy:
Args:
name_id (str): a unique name or id
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
self.name_id = name_id
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.tool = OnlineTool(need_log)
self.tool = OnlineTool()
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
"""
@@ -46,20 +44,21 @@ class OnlineStrategy:
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_online_models(self, models, cur_time=None, check_func=None, **kwargs):
def prepare_online_models(self, models, cur_time=None, check_func=None) -> List[object]:
"""
A typically implementation, but maybe you will need old models by online_tool.
Use trainer to train a list of tasks and set the trained model to `online`.
Select some models as the online models from the trained models.
NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them.
NOTE: This method sets all models offline and then sets the models prepared by this method (if any) online. So you can find the last online models by OnlineTool.online_models if you still need them.
Args:
tasks (list): a list of tasks.
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
**kwargs: will be passed to end_train which means will be passed to customized train method.
Returns:
List[object]: a list of selected models.
"""
if check_func is not None:
online_models = []
@@ -101,7 +100,6 @@ class RollingAverageStrategy(OnlineStrategy):
name_id: str,
task_template: Union[dict, List[dict]],
rolling_gen: RollingGen,
need_log=True,
):
"""
Init RollingAverageStrategy.
@@ -112,15 +110,14 @@ class RollingAverageStrategy(OnlineStrategy):
name_id (str): a unique name or id. Will be also the name of Experiment.
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
rolling_gen (RollingGen): an instance of RollingGen
need_log (bool, optional): print log or not. Defaults to True.
"""
super().__init__(name_id=name_id, need_log=need_log)
super().__init__(name_id=name_id)
self.exp_name = self.name_id
if not isinstance(task_template, list):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
self.tool = OnlineToolR(self.exp_name, need_log)
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
@@ -180,10 +177,9 @@ class RollingAverageStrategy(OnlineStrategy):
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
if self.need_log:
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []

View File

@@ -60,10 +60,9 @@ class RecordUpdater(metaclass=ABCMeta):
Update a specific recorder
"""
def __init__(self, record: Recorder, need_log=True, *args, **kwargs):
def __init__(self, record: Recorder, *args, **kwargs):
self.record = record
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
@abstractmethod
def update(self, *args, **kwargs):
@@ -78,7 +77,7 @@ class PredUpdater(RecordUpdater):
Update the prediction in the Recorder
"""
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
"""
Init PredUpdater.
@@ -96,7 +95,7 @@ class PredUpdater(RecordUpdater):
"""
# TODO: automate this hist_ref in the future.
super().__init__(record=record, need_log=need_log)
super().__init__(record=record)
self.to_date = to_date
self.hist_ref = hist_ref
@@ -138,8 +137,7 @@ class PredUpdater(RecordUpdater):
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
if self.need_log:
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
self.logger.info(f"The prediction in {self.record.info['id']} are latest. No need to update.")
return
# load dataset
@@ -157,5 +155,4 @@ class PredUpdater(RecordUpdater):
self.record.save_objects(**{"pred.pkl": cb_pred})
if self.need_log:
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")

View File

@@ -24,15 +24,11 @@ class OnlineTool:
ONLINE_TAG = "online" # the 'online' model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, need_log=True):
def __init__(self):
"""
Init OnlineTool.
Args:
need_log (bool, optional): print log or not. Defaults to True.
"""
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
def set_online_tag(self, tag, recorder: Union[list, object]):
"""
@@ -92,15 +88,14 @@ class OnlineToolR(OnlineTool):
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name:str, need_log=True):
def __init__(self, experiment_name: str):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
need_log (bool, optional): print log or not. Defaults to True.
"""
super().__init__(need_log=need_log)
super().__init__()
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
@@ -115,8 +110,7 @@ class OnlineToolR(OnlineTool):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
if self.need_log:
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder) -> str:
"""
@@ -164,7 +158,6 @@ class OnlineToolR(OnlineTool):
"""
online_models = self.online_models()
for rec in online_models:
PredUpdater(rec, to_date=to_date, need_log=self.need_log).update()
PredUpdater(rec, to_date=to_date).update()
if self.need_log:
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")