diff --git a/docs/advanced/serial.rst b/docs/advanced/serial.rst
index 8c0f83746..e0840069b 100644
--- a/docs/advanced/serial.rst
+++ b/docs/advanced/serial.rst
@@ -14,6 +14,9 @@ Serializable Class
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
+However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
+
+Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here `_).
Example
==========================
diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst
index d60049455..56a3137f9 100644
--- a/docs/advanced/task_management.rst
+++ b/docs/advanced/task_management.rst
@@ -19,7 +19,7 @@ An example of the entire process is shown `here `_.
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
@@ -30,15 +30,15 @@ Here is the base class of ``TaskGen``:
:members:
``Qlib`` provides a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments.
-This class allows users to verify the effect of data from different periods on the model in one experiment. More information in `here <../reference/api.html#TaskGen>`_.
+This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
Task Storing
===============
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_.
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
-Users **MUST** finished the configuration of `MongoDB `_ when using this module.
+Users **MUST** finish the configuration of `MongoDB `_ when using this module.
-Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make statement like this.
+Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
.. code-block:: python
@@ -55,8 +55,7 @@ More information of ``Task Manager`` can be found in `here <../reference/api.htm
Task Training
===============
-#FIXME: Trainer
-After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status.
+After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
@@ -64,16 +63,20 @@ It will run the whole workflow defined by ``task``, which includes *Model*, *Dat
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
-``Trainer`` will train a list of tasks and return a list of model recorder.
-``Qlib`` offer two kind of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
+
+.. autoclass:: qlib.model.trainer.Trainer
+ :members:
+
+``Trainer`` will train a list of tasks and return a list of model recorders.
+``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
-More information is in `here <../reference/api.html#Trainer>`_.
+`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
Task Collecting
===============
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
-`Collector <../reference/api.html#Collector>`_ can collect object from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
+`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
@@ -81,6 +84,6 @@ For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1:
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
-So the hierarchy is ``Collector``'s second step correspond to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
+So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
-For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_
\ No newline at end of file
+For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_.
\ No newline at end of file
diff --git a/docs/component/online.rst b/docs/component/online.rst
index e25173153..66331901f 100644
--- a/docs/component/online.rst
+++ b/docs/component/online.rst
@@ -9,12 +9,12 @@ Online Serving
Introduction
=============
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
-``Online Serving`` is a set of module for online models using latest data,
+``Online Serving`` is a set of modules for online models using the latest data,
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
`Here `_ are several examples for reference, which demonstrate different features of ``Online Serving``.
-If you have many models or `task` need to be managed, please consider `Task Management <../advanced/task_management.html>`_.
-The `examples `_ maybe based on `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
+If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
+The `examples `_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
=============
diff --git a/docs/reference/api.rst b/docs/reference/api.rst
index edba6228a..57f61f18b 100644
--- a/docs/reference/api.rst
+++ b/docs/reference/api.rst
@@ -226,4 +226,7 @@ Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
- :members:
\ No newline at end of file
+ :members:
+
+
+
\ No newline at end of file
diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py
index 175319885..4f3ac04b1 100644
--- a/examples/model_rolling/task_manager_rolling.py
+++ b/examples/model_rolling/task_manager_rolling.py
@@ -2,8 +2,8 @@
# Licensed under the MIT License.
"""
-This example shows how a TrainerRM work based on TaskManager with rolling tasks.
-After training, how to collect the rolling results will be showed in task_collecting.
+This example shows how a TrainerRM works based on TaskManager with rolling tasks.
+After training, how to collect the rolling results will be shown in task_collecting.
"""
from pprint import pprint
diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py
index 48433c6d5..c09b10aa7 100644
--- a/examples/online_srv/online_management_simulate.py
+++ b/examples/online_srv/online_management_simulate.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-This examples is about how can simulate the OnlineManager based on rolling tasks.
+This example is about how can simulate the OnlineManager based on rolling tasks.
"""
import fire
@@ -112,8 +112,8 @@ class OnlineSimulationExample:
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
- ) # The rolling tasks generator, ds_extra_mod_func is None because we just need simulate to 2018-10-31 and needn't change handler end time.
- self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
+ ) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
+ self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -138,8 +138,6 @@ class OnlineSimulationExample:
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
- print("========== online history ==========")
- print(self.rolling_online_manager.history)
if __name__ == "__main__":
diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py
index e15daeb29..e5c37dac6 100644
--- a/examples/online_srv/rolling_online_management.py
+++ b/examples/online_srv/rolling_online_management.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-This example show how OnlineManager works with rolling tasks.
+This example shows how OnlineManager works with rolling tasks.
There are two parts including first train and routine.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
@@ -154,7 +154,7 @@ if __name__ == "__main__":
# python rolling_online_management.py first_run
####### to update the models and predictions after the trading time, use the command below
- # python rolling_online_management.py after_day
+ # python rolling_online_management.py routine
####### to define your own parameters, use `--`
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py
index 6e2725c7a..228bc0dac 100644
--- a/examples/online_srv/update_online_pred.py
+++ b/examples/online_srv/update_online_pred.py
@@ -2,10 +2,10 @@
# Licensed under the MIT License.
"""
-This example show how OnlineTool works when we need update prediction.
+This example shows how OnlineTool works when we need update prediction.
There are two parts including first_train and update_online_pred.
-Firstly, we will finish the training and set the trained model to `online` model.
-Next, we will finish updating online prediction.
+Firstly, we will finish the training and set the trained models to the `online` models.
+Next, we will finish updating online predictions.
"""
import fire
import qlib
diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py
index 6040517e2..0f48ce728 100644
--- a/qlib/model/ens/ensemble.py
+++ b/qlib/model/ens/ensemble.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions.
+Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them into an ensemble prediction.
"""
from typing import Union
@@ -11,29 +11,41 @@ from qlib.utils import FLATTEN_TUPLE, flatten_dict
class Ensemble:
- """Merge the objects in an Ensemble."""
+ """Merge the ensemble_dict into an ensemble object.
- def __call__(self, ensemble_dict: dict, *args, **kwargs):
- """Merge the ensemble_dict into an ensemble object.
- For example: {Rollinga_b: object, Rollingb_c: object} -> object
+ For example: {Rollinga_b: object, Rollingb_c: object} -> object
+ When calling this class:
+
Args:
- ensemble_dict (dict): the ensemble dict waiting for merging like {name: things}
+ ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
Returns:
object: the ensemble object
- """
+ """
+
+ def __call__(self, ensemble_dict: dict, *args, **kwargs):
raise NotImplementedError(f"Please implement the `__call__` method.")
class SingleKeyEnsemble(Ensemble):
"""
- Extract the object if there is only one key and value in dict. Make result more readable.
+ Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
- If there are more than 1 key or less than 1 key, then do nothing.
+
+ If there is more than 1 key or less than 1 key, then do nothing.
Even you can run this recursively to make dict more readable.
- NOTE: Default run recursively.
+
+ NOTE: Default runs recursively.
+
+ When calling this class:
+
+ Args:
+ ensemble_dict (dict): the dict. The key of the dict will be ignored.
+
+ Returns:
+ dict: the readable dict.
"""
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
@@ -52,12 +64,11 @@ class SingleKeyEnsemble(Ensemble):
class RollingEnsemble(Ensemble):
- """Merge the rolling objects in an Ensemble"""
+ """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
- def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
- """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
+ NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
- NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
+ When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
@@ -65,7 +76,9 @@ class RollingEnsemble(Ensemble):
Returns:
pd.DataFrame: the complete result of rolling.
- """
+ """
+
+ def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
@@ -76,11 +89,12 @@ class RollingEnsemble(Ensemble):
class AverageEnsemble(Ensemble):
- def __call__(self, ensemble_dict: dict):
- """
- Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
+ """
+ Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
- NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.
+ NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.
+
+ When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
@@ -88,7 +102,8 @@ class AverageEnsemble(Ensemble):
Returns:
pd.DataFrame: the complete result of averaging and standardizing.
- """
+ """
+ def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())
diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py
index a00a8ea0e..93903f433 100644
--- a/qlib/model/ens/group.py
+++ b/qlib/model/ens/group.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-Group can group a set of object based on `group_func` and change them to a dict.
+Group can group a set of objects based on `group_func` and change them to a dict.
After group, we provide a method to reduce them.
For example:
@@ -21,10 +21,11 @@ class Group:
"""Group the objects based on dict"""
def __init__(self, group_func=None, ens: Ensemble = None):
- """init Group.
+ """
+ Init Group.
Args:
- group_func (Callable, optional): Given a dict and return the group key and one of group elements.
+ group_func (Callable, optional): Given a dict and return the group key and one of the group elements.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
@@ -37,7 +38,7 @@ class Group:
def group(self, *args, **kwargs) -> dict:
"""
- Group a set of object and change them to a dict.
+ Group a set of objects and change them to a dict.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
@@ -51,7 +52,7 @@ class Group:
def reduce(self, *args, **kwargs) -> dict:
"""
- Reduce grouped dict in some way.
+ Reduce grouped dict.
For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
@@ -63,7 +64,7 @@ class Group:
else:
raise NotImplementedError(f"Please specify valid `_ens_func`.")
- def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict:
+ def __call__(self, ungrouped_dict: dict, n_jobs:int=1, verbose:int=0, *args, **kwargs) -> dict:
"""
Group the ungrouped_dict into different groups.
@@ -72,10 +73,12 @@ class Group:
Returns:
dict: grouped_dict like {G1: object, G2: object}
+ n_jobs: how many progress you need.
+ verbose: the print mode for Parallel.
"""
# NOTE: The multiprocessing will raise error if you use `Serializable`
- # Because the `Serializable` will affect the behaviours of pickle
+ # Because the `Serializable` will affect the behaviors of pickle
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
key_l = []
@@ -87,12 +90,12 @@ class Group:
class RollingGroup(Group):
- """group the rolling dict"""
+ """Group the rolling dict"""
def group(self, rolling_dict: dict) -> dict:
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
- NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly.
+ NOTE: There is an assumption which is the rolling key is at the end of the key tuple, because the rolling results always need to be ensemble firstly.
Args:
rolling_dict (dict): an rolling dict. If the key is not a tuple, then do nothing.
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
index f261a4b4e..0c9c9e2c2 100644
--- a/qlib/model/trainer.py
+++ b/qlib/model/trainer.py
@@ -2,13 +2,13 @@
# Licensed under the MIT License.
"""
-The Trainer will train a list of tasks and return a list of model recorder.
+The Trainer will train a list of tasks and return a list of model recorders.
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
-This is concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
-In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
+This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
+In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
-``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
+``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
"""
import socket
@@ -25,7 +25,7 @@ from qlib.workflow.task.manage import TaskManager, run_task
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
- Begin a task training to start a recorder and save the task config.
+ Begin task training to start a recorder and save the task config.
Args:
task_config (dict): the config of a task
@@ -94,7 +94,7 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
Returns
----------
- Recorder : The instance of the recorder
+ Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
@@ -103,7 +103,7 @@ def task_train(task_config: dict, experiment_name: str) -> Recorder:
class Trainer:
"""
- The trainer can train a list of model.
+ The trainer can train a list of models.
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
"""
@@ -112,10 +112,10 @@ class Trainer:
def train(self, tasks: list, *args, **kwargs) -> list:
"""
- Given a list of model definition, begin a training and return the models.
+ Given a list of task definitions, begin training, and return the models.
- For Trainer, it finish real training in this method.
- For DelayTrainer, it only do some preparation in this method.
+ For Trainer, it finishes real training in this method.
+ For DelayTrainer, it only does some preparation in this method.
Args:
tasks: a list of tasks
@@ -127,11 +127,11 @@ class Trainer:
def end_train(self, models: list, *args, **kwargs) -> list:
"""
- Given a list of models, finished something in the end of training if you need.
- The models maybe Recorder, txt file, database and so on.
+ Given a list of models, finished something at the end of training if you need.
+ The models may be Recorder, txt file, database, and so on.
- For Trainer, it do some finishing touches in this method.
- For DelayTrainer, it finish real training in this method.
+ For Trainer, it does some finishing touches in this method.
+ For DelayTrainer, it finishes real training in this method.
Args:
models: a list of models
@@ -155,9 +155,9 @@ class Trainer:
class TrainerR(Trainer):
"""
Trainer based on (R)ecorder.
- It will train a list of tasks and return a list of model recorder in a linear way.
+ It will train a list of tasks and return a list of model recorders in a linear way.
- Assumption: models were defined by `task` and the results will saved to `Recorder`
+ Assumption: models were defined by `task` and the results will be saved to `Recorder`.
"""
# Those tag will help you distinguish whether the Recorder has finished traning
@@ -182,13 +182,13 @@ class TrainerR(Trainer):
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
Args:
- tasks (list): a list of definition based on `task` dict
- train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
+ tasks (list): a list of definitions based on `task` dict
+ train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for train_func.
Returns:
- list: a list of Recorders
+ List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
@@ -204,6 +204,15 @@ class TrainerR(Trainer):
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
+ """
+ Set STATUS_END tag to the recorders.
+
+ Args:
+ recs (list): a list of trained recorders.
+
+ Returns:
+ List[Recorder]: the same list as the param.
+ """
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -231,15 +240,15 @@ class DelayTrainerR(TrainerR):
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
-
+
Args:
recs (list): a list of Recorder, the tasks have been saved to them
- end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
+ end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
-
+
Returns:
- list: a list of Recorders
+ List[Recorder]: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
@@ -256,7 +265,7 @@ class DelayTrainerR(TrainerR):
class TrainerRM(Trainer):
"""
Trainer based on (R)ecorder and Task(M)anager.
- It can train a list of tasks and return a list of model recorder in a multiprocessing way.
+ It can train a list of tasks and return a list of model recorders in a multiprocessing way.
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
"""
@@ -296,15 +305,15 @@ class TrainerRM(Trainer):
Users can customize their train_func to realize multiple processes or even multiple machines.
Args:
- tasks (list): a list of definition based on `task` dict
- train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
+ tasks (list): a list of definitions based on `task` dict
+ train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs: the params for train_func.
Returns:
- list: a list of Recorders
+ List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
@@ -334,6 +343,15 @@ class TrainerRM(Trainer):
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
+ """
+ Set STATUS_END tag to the recorders.
+
+ Args:
+ recs (list): a list of trained recorders.
+
+ Returns:
+ List[Recorder]: the same list as the param.
+ """
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -368,12 +386,14 @@ class DelayTrainerRM(TrainerRM):
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
+
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
experiment_name (str): the experiment name, None for use default name.
+
Returns:
- list: a list of Recorders
+ List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
@@ -390,7 +410,7 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
- NOTE: This method will train all STATUS_PART_DONE tasks in task pool, not only the ``recs``.
+ NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
@@ -399,7 +419,7 @@ class DelayTrainerRM(TrainerRM):
kwargs: the params for end_train_func.
Returns:
- list: a list of Recorders
+ List[Recorder]: a list of Recorders
"""
if end_train_func is None:
diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py
index 2ff687737..77857182d 100644
--- a/qlib/utils/__init__.py
+++ b/qlib/utils/__init__.py
@@ -719,7 +719,7 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
-def flatten_dict(d, parent_key="", sep="."):
+def flatten_dict(d, parent_key="", sep=".") -> dict:
"""
Flatten a nested dict.
diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py
index 352949198..c7c51bac2 100644
--- a/qlib/utils/serial.py
+++ b/qlib/utils/serial.py
@@ -3,10 +3,12 @@
from pathlib import Path
import pickle
+import typing
import dill
from typing import Union
+
class Serializable:
"""
Serializable will change the behaviors of pickle.
@@ -16,7 +18,7 @@ class Serializable:
"""
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
- default_dump_all = False # if dump all things
+ default_dump_all = False # if dump all things
def __init__(self):
self._dump_all = self.default_dump_all
@@ -76,6 +78,14 @@ class Serializable:
del self.__dict__[self.FLAG_KEY]
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
+ """
+ Dump self to a pickle file.
+
+ Args:
+ path (Union[Path, str]): the path to dump
+ dump_all (bool, optional): if need to dump all things. Defaults to None.
+ exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
+ """
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
self.get_backend().dump(self, f)
@@ -83,7 +93,7 @@ class Serializable:
@classmethod
def load(cls, filepath):
"""
- load the collector from a file
+ Load the collector from a filepath.
Args:
filepath (str): the path of file
@@ -104,10 +114,10 @@ class Serializable:
@classmethod
def get_backend(cls):
"""
- Return the backend of a Serializable class. The value will be "pickle" or "dill".
+ Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
Returns:
- str: The value of "pickle" or "dill"
+ module: pickle or dill module based on pickle_backend
"""
if cls.pickle_backend == "pickle":
return pickle
@@ -115,4 +125,3 @@ class Serializable:
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
-
diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py
index 63169b58d..7d1c723f3 100644
--- a/qlib/workflow/online/manager.py
+++ b/qlib/workflow/online/manager.py
@@ -4,12 +4,20 @@
"""
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
-With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models.
-In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
-So this module provide a series methods to control this process.
+With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
+In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
+So this module provides a series of methods to control this process.
-This module also provide a method to simulate `Online Strategy <#Online Strategy>`_ in the history.
+This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
+
+There are total 3 situations for using the different trainer:
+
+1: Online: Only use Trainer.
+
+2: Simulate with temporal dependence: Only use Trainer.
+
+3: Simulate without temporal dependence: Use Trainer or DelayTrainer.
"""
import logging
@@ -20,7 +28,7 @@ from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
-from qlib.model.trainer import DelayTrainerR, Trainer
+from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
from qlib.utils import flatten_dict
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
@@ -30,9 +38,12 @@ from qlib.workflow.task.collect import MergeCollector
class OnlineManager(Serializable):
"""
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
- It also provide a history recording which models are onlined at what time.
+ It also provides a history recording of which models are online at what time.
"""
+ STATUS_SIMULATING = "simulating" # when calling `simulate`
+ STATUS_NORMAL = "normal" # the normal status
+
def __init__(
self,
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
@@ -46,8 +57,8 @@ class OnlineManager(Serializable):
Args:
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
- begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
- trainer (Trainer): the trainer to train task. None for using DelayTrainerR.
+ begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date.
+ trainer (Trainer): the trainer to train task. None for using TrainerR.
freq (str, optional): data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
@@ -59,12 +70,13 @@ class OnlineManager(Serializable):
begin_time = D.calendar(freq=self.freq).max()
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
- # OnlineManager will recorder the history of online models, which is a dict like {begin_time, {strategy, [online_models]}}. begin_time means when online_models are onlined.
+ # OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
self.history = {}
if trainer is None:
- trainer = DelayTrainerR()
+ trainer = TrainerR()
self.trainer = trainer
self.signals = None
+ self.status = self.STATUS_NORMAL
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
@@ -75,37 +87,36 @@ class OnlineManager(Serializable):
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
model_kwargs (dict): the params for `prepare_online_models`
"""
- models_list = []
if strategies is None:
strategies = self.strategies
for strategy in strategies:
+
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
- models_list.append(models)
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+ self.logger.info(f"Finished training {len(models)} models.")
- for strategy, models in zip(strategies, models_list):
- self.prepare_online_models(strategy, models, model_kwargs=model_kwargs)
+ online_models = strategy.prepare_online_models(models, **model_kwargs)
+ self.history.setdefault(self.cur_time, {})[strategy] = online_models
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
- delay: bool = False,
task_kwargs: dict = {},
model_kwargs: dict = {},
signal_kwargs: dict = {},
):
"""
- Run typical update process for every strategy and record the online history.
+ Typical update process for every strategy and record the online history.
The typical update process after a routine, such as day by day or month by month.
- The process is: Prepare signals -> Prepare tasks -> Prepare online models.
+ The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.
If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
Args:
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
- delay (bool): if delay prepare signals and models
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
signal_kwargs (dict): the params for `prepare_signals`
@@ -113,40 +124,23 @@ class OnlineManager(Serializable):
if cur_time is None:
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
- models_list = []
+
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
- if not delay:
+ if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
+ if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.logger.info(f"Finished training {len(models)} models.")
- models_list.append(models)
+ online_models = strategy.prepare_online_models(models, **model_kwargs)
+ self.history.setdefault(self.cur_time, {})[strategy] = online_models
- for strategy, models in zip(self.strategies, models_list):
- self.prepare_online_models(strategy, models, delay=delay, model_kwargs=model_kwargs)
-
- if not delay:
+ if not self.trainer.is_delay():
self.prepare_signals(**signal_kwargs)
- def prepare_online_models(
- self, strategy: OnlineStrategy, models: list, delay: bool = False, model_kwargs: dict = {}
- ):
- """
- Prepare online model for strategy, including end_train, reset_online_tag and add history.
-
- Args:
- strategy (OnlineStrategy): the instance of strategy.
- models (list): a list of models.
- delay (bool, optional): if delay prepare models. Defaults to False.
- model_kwargs (dict, optional): the params for `prepare_online_models`.
- """
- if not delay:
- models = self.trainer.end_train(models, experiment_name=strategy.name_id)
- online_models = strategy.prepare_online_models(models, **model_kwargs)
- self.history.setdefault(self.cur_time, {})[strategy] = online_models
-
def get_collector(self) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
@@ -162,7 +156,7 @@ class OnlineManager(Serializable):
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
"""
- Add some new strategies to online manager.
+ Add some new strategies to OnlineManager.
Args:
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
@@ -174,9 +168,9 @@ class OnlineManager(Serializable):
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
"""
- After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
+ After preparing the data of the last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for the next routine.
- NOTE: Given a set prediction, all signals before these prediction end time will be prepared well.
+ NOTE: Given a set prediction, all signals before these prediction end times will be prepared well.
Even if the latest signal already exists, the latest calculation result will be overwritten.
@@ -185,7 +179,7 @@ class OnlineManager(Serializable):
Given a prediction of a certain time, all signals before this time will be prepared well.
Args:
- prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results after mergecollector must be {xxx:pred}.
+ prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}.
over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False.
Returns:
@@ -209,18 +203,18 @@ class OnlineManager(Serializable):
Returns:
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
- pd.DataFrame for multiple signals, for example, buy and sell operation use different trading signal.
+ pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
return self.signals
SIM_LOG_LEVEL = logging.INFO + 1
SIM_LOG_NAME = "SIMULATE_INFO"
- def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}):
+ def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}) -> Union[pd.Series, pd.DataFrame]:
"""
- Starting from current time, this method will simulate every routine in OnlineManager until end time.
+ Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
- Considering the parallel training, the models and signals can be perpared after all routine simulating.
+ Considering the parallel training, the models and signals can be prepared after all routine simulating.
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
@@ -232,8 +226,10 @@ class OnlineManager(Serializable):
signal_kwargs (dict): the params for `prepare_signals`
Returns:
- HyperCollector: the OnlineManager's collector
+ Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
+ pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
+ self.status = self.STATUS_SIMULATING
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
self.first_train()
@@ -245,7 +241,6 @@ class OnlineManager(Serializable):
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
self.routine(
cur_time,
- delay=self.trainer.is_delay(),
task_kwargs=task_kwargs,
model_kwargs=model_kwargs,
signal_kwargs=signal_kwargs,
@@ -257,11 +252,12 @@ class OnlineManager(Serializable):
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
- return self.get_collector()
+ self.status = self.STATUS_NORMAL
+ return self.get_signals()
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
"""
- Prepare all models and signals if there are something waiting for prepare.
+ Prepare all models and signals if something is waiting for preparation.
Args:
model_kwargs: the params for `prepare_online_models`
@@ -270,6 +266,6 @@ class OnlineManager(Serializable):
for cur_time, strategy_models in self.history.items():
self.cur_time = cur_time
for strategy, models in strategy_models.items():
- self.prepare_online_models(strategy, models, delay=False, model_kwargs=model_kwargs)
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py
index 04c854f79..491b191dd 100644
--- a/qlib/workflow/online/strategy.py
+++ b/qlib/workflow/online/strategy.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-OnlineStrategy is a set of strategy for online serving.
+OnlineStrategy module is an element of online serving.
"""
from copy import deepcopy
@@ -19,7 +19,7 @@ from qlib.workflow.task.utils import TimeAdjuster
class OnlineStrategy:
"""
- OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared.
+ OnlineStrategy is working with `Online Manager <#Online Manager>`_, responding to how the tasks are generated, the models are updated and signals are prepared.
"""
def __init__(self, name_id: str):
@@ -28,7 +28,7 @@ class OnlineStrategy:
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
Args:
- name_id (str): a unique name or id
+ name_id (str): a unique name or id.
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
"""
self.name_id = name_id
@@ -40,29 +40,29 @@ class OnlineStrategy:
After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest)..
Return the new tasks waiting for training.
- You can find last online models by OnlineTool.online_models.
+ You can find the last online models by OnlineTool.online_models.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
- def prepare_online_models(self, models, cur_time=None) -> List[object]:
+ def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
"""
Select some models from trained models and set them to online models.
- This is a typically implementation to online all trained models, you can override it to implement complex method.
- You can find last online models by OnlineTool.online_models if you still need them.
+ This is a typical implementation to online all trained models, you can override it to implement the complex method.
+ You can find the last online models by OnlineTool.online_models if you still need them.
- NOTE: Reset all online models to trained model. If there is no trained models, then do nothing.
+ NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
Args:
models (list): a list of models.
- cur_time (pd.Dataframe): current time from OnlineManger. None for latest.
+ cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
Returns:
List[object]: a list of online models.
"""
- if not models:
+ if not trained_models:
return self.tool.online_models()
- self.tool.reset_online_tag(models)
- return models
+ self.tool.reset_online_tag(trained_models)
+ return trained_models
def first_tasks(self) -> List[dict]:
"""
@@ -87,7 +87,7 @@ class OnlineStrategy:
class RollingStrategy(OnlineStrategy):
"""
- This example strategy always use latest rolling model as online model.
+ This example strategy always uses the latest rolling model sas online models.
"""
def __init__(
@@ -99,11 +99,11 @@ class RollingStrategy(OnlineStrategy):
"""
Init RollingStrategy.
- Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.
+ Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.
Args:
- name_id (str): a unique name or id. Will be also the name of Experiment.
- task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
+ name_id (str): a unique name or id. Will be also the name of the Experiment.
+ task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
rolling_gen (RollingGen): an instance of RollingGen
"""
super().__init__(name_id=name_id)
@@ -117,9 +117,10 @@ class RollingStrategy(OnlineStrategy):
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
"""
- Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must can distinguish results in different models.
- Assumption: the models can be distinguished based on model name and rolling test segments.
- If you do not want this assumption, please implement your own method or use another rec_key_func.
+ Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.
+
+ Assumption: the models can be distinguished based on the model name and rolling test segments.
+ If you do not want this assumption, please implement your method or use another rec_key_func.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
@@ -160,9 +161,9 @@ class RollingStrategy(OnlineStrategy):
def prepare_tasks(self, cur_time) -> List[dict]:
"""
- Prepare new tasks based on cur_time (None for latest).
+ Prepare new tasks based on cur_time (None for the latest).
- You can find last online models by OnlineToolR.online_models.
+ You can find the last online models by OnlineToolR.online_models.
Returns:
List[dict]: a list of new tasks.
@@ -198,7 +199,7 @@ class RollingStrategy(OnlineStrategy):
rec_list (List[Recorder]): a list of Recorder
Returns:
- List[Recorder], pd.Timestamp: the latest recorders and its test end time
+ List[Recorder], pd.Timestamp: the latest recorders and their test end time
"""
if len(rec_list) == 0:
return rec_list, None
diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py
index 9cb294169..561f7e18a 100644
--- a/qlib/workflow/online/update.py
+++ b/qlib/workflow/online/update.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-Updater is a module to update artifacts such as predictions, when the stock data is updating.
+Updater is a module to update artifacts such as predictions when the stock data is updating.
"""
from abc import ABCMeta, abstractmethod
@@ -87,7 +87,7 @@ class PredUpdater(RecordUpdater):
update to prediction to the `to_date`
hist_ref : int
Sometimes, the dataset will have historical depends.
- Leave the problem to user to set the length of historical dependency
+ Leave the problem to users to set the length of historical dependency
.. note::
@@ -112,7 +112,7 @@ class PredUpdater(RecordUpdater):
"""
Load dataset
- Seperating this function will make it easier to reuse the dataset
+ Separating this function will make it easier to reuse the dataset
Returns:
DatasetH: the instance of DatasetH
@@ -125,7 +125,7 @@ class PredUpdater(RecordUpdater):
def update(self, dataset: DatasetH = None):
"""
- Update the precition in a recorder
+ Update the prediction in a recorder.
Args:
DatasetH: the instance of DatasetH. None for reprepare.
diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py
index 3c2774cec..f3ef13aa9 100644
--- a/qlib/workflow/online/utils.py
+++ b/qlib/workflow/online/utils.py
@@ -3,8 +3,8 @@
"""
OnlineTool is a module to set and unset a series of `online` models.
-The `online` models are some decisive models in some time point, which can be changed with the change of time.
-This allows us to use efficient submodels as the market style changing.
+The `online` models are some decisive models in some time points, which can be changed with the change of time.
+This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
@@ -17,7 +17,7 @@ from qlib.workflow.task.utils import list_recorders
class OnlineTool:
"""
- OnlineTool will manage `online` models in an experiment which includes the models recorder.
+ OnlineTool will manage `online` models in an experiment that includes the model recorders.
"""
ONLINE_KEY = "online_status" # the online status key in recorder
@@ -74,10 +74,10 @@ class OnlineTool:
def update_online_pred(self, to_date=None):
"""
- Update the predictions of `online` models to a date.
+ Update the predictions of `online` models to to_date.
Args:
- to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest.
+ to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
@@ -151,15 +151,16 @@ class OnlineToolR(OnlineTool):
def update_online_pred(self, to_date=None):
"""
- Update the predictions of online models to a date.
+ Update the predictions of online models to to_date.
Args:
- to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest time in Calendar.
+ to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
"""
online_models = self.online_models()
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
+ # Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py
index 1080d07f4..3a8bd1f2c 100644
--- a/qlib/workflow/task/collect.py
+++ b/qlib/workflow/task/collect.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on.
+Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
"""
from typing import Callable, Dict, List
@@ -17,15 +17,18 @@ class Collector(Serializable):
def __init__(self, process_list=[]):
"""
+ Init Collector.
+
Args:
- process_list (list, optional): process_list (list or Callable): the list of processors or the instance of processor to process dict.
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
"""
if not isinstance(process_list, list):
process_list = [process_list]
self.process_list = process_list
def collect(self) -> dict:
- """Collect the results and return a dict like {key: things}
+ """
+ Collect the results and return a dict like {key: things}
Returns:
dict: the dict after collecting.
@@ -42,13 +45,14 @@ class Collector(Serializable):
@staticmethod
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
- """do a series of processing to the dict returned by collect and return a dict like {key: things}
- For example: you can group and ensemble.
+ """
+ Do a series of processing to the dict returned by collect and return a dict like {key: things}
+ For example, you can group and ensemble.
Args:
collected_dict (dict): the dict return by `collect`
- process_list (list or Callable): the list of processors or the instance of processor to process dict.
- The processor order is same as the list order.
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
+ The processor order is the same as the list order.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
Returns:
@@ -68,7 +72,7 @@ class Collector(Serializable):
def __call__(self, *args, **kwargs) -> dict:
"""
- do the workflow including collect and process_collect
+ Do the workflow including ``collect`` and ``process_collect``
Returns:
dict: the dict after collecting and processing.
@@ -93,11 +97,13 @@ class MergeCollector(Collector):
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
"""
+ Init MergeCollector.
+
Args:
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
- None for use tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
+ None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
"""
super().__init__(process_list=process_list)
self.collector_dict = collector_dict
@@ -105,7 +111,7 @@ class MergeCollector(Collector):
def collect(self) -> dict:
"""
- Collect all result of collector_dict and change the outermost key to a recombination key.
+ Collect all results of collector_dict and change the outermost key to a recombination key.
Returns:
dict: the dict after collecting.
@@ -133,11 +139,12 @@ class RecorderCollector(Collector):
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
):
- """init RecorderCollector
+ """
+ Init RecorderCollector.
Args:
- experiment (Experiment or str): an instance of a Experiment or the name of a Experiment
- process_list (list or Callable): the list of processors or the instance of processor to process dict.
+ experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
+ process_list (list or Callable): the list of processors or the instance of a processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
@@ -157,12 +164,13 @@ class RecorderCollector(Collector):
self.rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
- """Collect different artifacts based on recorder after filtering.
+ """
+ Collect different artifacts based on recorder after filtering.
Args:
- artifacts_key (str or List, optional): the artifacts key you want to get. If None, use default.
- rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use default.
- only_exist (bool, optional): if only collect the artifacts when a recorder really have.
+ artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
+ rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
+ only_exist (bool, optional): if only collect the artifacts when a recorder really has.
If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.
Returns:
diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py
index 7e08c76f4..cdebf5049 100644
--- a/qlib/workflow/task/gen.py
+++ b/qlib/workflow/task/gen.py
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
-Task generator can generate many tasks based on TaskGen and some task templates.
+TaskGenerator module can generate many tasks based on TaskGen and some task templates.
"""
import abc
import copy
@@ -10,7 +10,8 @@ from .utils import TimeAdjuster
def task_generator(tasks, generators) -> list:
- """Use a list of TaskGen and a list of task templates to generate different tasks.
+ """
+ Use a list of TaskGen and a list of task templates to generate different tasks.
For examples:
@@ -47,7 +48,7 @@ def task_generator(tasks, generators) -> list:
class TaskGen(metaclass=abc.ABCMeta):
"""
- the base class for generate different tasks
+ The base class for generating different tasks
Example 1:
@@ -66,7 +67,7 @@ class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def generate(self, task: dict) -> List[dict]:
"""
- generate different tasks based on a task template
+ Generate different tasks based on a task template
Parameters
----------
@@ -87,7 +88,7 @@ class TaskGen(metaclass=abc.ABCMeta):
return self.generate(*args, **kwargs)
-def handler_mod(task: dict, rg):
+def handler_mod(task: dict, rolling_gen):
"""
Help to modify the handler end time when using RollingGen
@@ -96,14 +97,14 @@ def handler_mod(task: dict, rg):
rg (RollingGen): an instance of RollingGen
"""
try:
- interval = rg.ta.cal_interval(
+ interval = rolling_gen.ta.cal_interval(
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
- task["dataset"]["kwargs"]["segments"][rg.test_key][1],
+ task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
)
# if end_time < the end of test_segments, then change end_time to allow load more data
if interval < 0:
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
- task["dataset"]["kwargs"]["segments"][rg.test_key][1]
+ task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
)
except KeyError:
# Maybe dataset do not have handler, then do nothing.
@@ -126,7 +127,7 @@ class RollingGen(TaskGen):
rolling type (expanding, sliding)
ds_extra_mod_func: Callable
A method like: handler_mod(task: dict, rg: RollingGen)
- Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of handler of dataset.
+ Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
"""
self.step = step
self.rtype = rtype
@@ -142,7 +143,7 @@ class RollingGen(TaskGen):
Parameters
----------
- task : dict
+ task: dict
A dict describing a task. For example.
.. code-block:: python
@@ -184,7 +185,7 @@ class RollingGen(TaskGen):
Returns
----------
- typing.List[dict]: a list of tasks
+ List[dict]: a list of tasks
"""
res = []
diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py
index 40f868295..658eec4d6 100644
--- a/qlib/workflow/task/manage.py
+++ b/qlib/workflow/task/manage.py
@@ -2,7 +2,7 @@
# Licensed under the MIT License.
"""
-TaskManager can fetch unused tasks automatically and manager the lifecycle of a set of tasks with error handling.
+TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
These features can run tasks concurrently and ensure every task will be used only once.
Task Manager will store all tasks in `MongoDB `_.
Users **MUST** finished the configuration of `MongoDB `_ when using this module.
@@ -10,7 +10,7 @@ Users **MUST** finished the configuration of `MongoDB
A task in TaskManager consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
-- tasks result information : A user can get the task with the task description and task result.
+- tasks result: A user can get the task with the task description and task result.
"""
import concurrent
import pickle
@@ -44,7 +44,7 @@ class TaskManager:
'res': pickle serialized task result,
}
- The tasks manager assume that you will only update the tasks you fetched.
+ The tasks manager assumes that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
.. note::
@@ -53,7 +53,7 @@ class TaskManager:
Here are four status which are:
- STATUS_WAITING: waiting for train
+ STATUS_WAITING: waiting for training
STATUS_RUNNING: training
@@ -85,7 +85,7 @@ class TaskManager:
def list(self) -> list:
"""
- list the all collection(task_pool) of the db
+ List the all collection(task_pool) of the db
Returns:
list
@@ -112,6 +112,10 @@ class TaskManager:
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
+
+ Args:
+ task: old task
+ new_task: new task
"""
new_task = self._encode_task(new_task)
query = {"_id": ObjectId(task["_id"])}
@@ -122,7 +126,15 @@ class TaskManager:
self.task_pool.replace_one(query, new_task)
def insert_task(self, task):
+ """
+ Insert a task.
+ Args:
+ task: the task waiting for insert
+
+ Returns:
+ pymongo.results.InsertOneResult
+ """
try:
insert_result = self.task_pool.insert_one(task)
except InvalidDocument:
@@ -132,7 +144,7 @@ class TaskManager:
def insert_task_def(self, task_def):
"""
- insert a task to task_pool
+ Insert a task to task_pool
Parameters
----------
@@ -155,8 +167,8 @@ class TaskManager:
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
"""
- If the tasks in task_def_l is new, then insert new tasks into the task_pool, and record inserted_id.
- If a task is not new, then query its _id.
+ If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.
+ If a task is not new, then just query its _id.
Parameters
----------
@@ -169,7 +181,7 @@ class TaskManager:
Returns
-------
- list
+ List[str]
a list of the _id of task_def_l
"""
new_tasks = []
@@ -202,7 +214,7 @@ class TaskManager:
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
"""
- Use query to fetch tasks
+ Use query to fetch tasks.
Args:
query (dict, optional): query dict. Defaults to {}.
@@ -257,6 +269,7 @@ class TaskManager:
def query(self, query={}, decode=True):
"""
+ Query task in collection.
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
Parameters
@@ -330,7 +343,16 @@ class TaskManager:
query["_id"] = ObjectId(query["_id"])
self.task_pool.delete_many(query)
- def task_stat(self, query={}):
+ def task_stat(self, query={}) -> dict:
+ """
+ Count the tasks in every status.
+
+ Args:
+ query (dict, optional): the query dict. Defaults to {}.
+
+ Returns:
+ dict
+ """
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
@@ -341,6 +363,12 @@ class TaskManager:
return status_stat
def reset_waiting(self, query={}):
+ """
+ Reset all running task into waiting status. Can be used when some running task exit unexpected.
+
+ Args:
+ query (dict, optional): the query dict. Defaults to {}.
+ """
query = query.copy()
# default query
if "status" not in query:
@@ -400,7 +428,7 @@ def run_task(
**kwargs,
):
"""
- While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
+ While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
After running this method, here are 4 situations (before_status -> after_status):
diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py
index ed5e1a235..89059e9f8 100644
--- a/qlib/workflow/task/utils.py
+++ b/qlib/workflow/task/utils.py
@@ -19,8 +19,9 @@ from typing import Union
def get_mongodb() -> Database:
"""
- Get database in MongoDB, which means you need to declare the address and the name of database.
- for example:
+ Get database in MongoDB, which means you need to declare the address and the name of a database at first.
+
+ For example:
Using qlib.init():
@@ -52,10 +53,10 @@ def get_mongodb() -> Database:
def list_recorders(experiment, rec_filter_func=None):
"""
- List all recorders which can pass the filter in a experiment.
+ List all recorders which can pass the filter in an experiment.
Args:
- experiment (str or Experiment): the name of a Experiment or a instance
+ experiment (str or Experiment): the name of an Experiment or an instance
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
Returns:
@@ -82,11 +83,17 @@ class TimeAdjuster:
self.cals = D.calendar(future=future, end_time=end_time)
def set_end_time(self, end_time=None):
+ """
+ Set end time. None for use calendar's end time.
+
+ Args:
+ end_time
+ """
self.cals = D.calendar(future=self._future, end_time=end_time)
def get(self, idx: int):
"""
- Get datetime by index
+ Get datetime by index.
Parameters
----------
@@ -105,7 +112,7 @@ class TimeAdjuster:
def align_idx(self, time_point, tp_type="start") -> int:
"""
- Align the index of time_point in the calendar
+ Align the index of time_point in the calendar.
Parameters
----------
@@ -155,7 +162,7 @@ class TimeAdjuster:
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
"""
- align the given date to trade date
+ Align the given date to the trade date
for example:
@@ -184,7 +191,7 @@ class TimeAdjuster:
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
"""
- truncate the segment based on the test_start date
+ Truncate the segment based on the test_start date
Parameters
----------
@@ -215,7 +222,7 @@ class TimeAdjuster:
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
- shift the datatime of segment
+ Shift the datatime of segment
Parameters
----------