mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge pull request #290 from you-n-g/online_srv

init version of online serving and rolling
This commit is contained in:
you-n-g
2021-05-17 17:35:29 +08:00
committed by GitHub
39 changed files with 3903 additions and 125 deletions

BIN
docs/_static/img/online_serving.png vendored Normal file

Binary file not shown.


View File

@@ -14,6 +14,9 @@ Serializable Class
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
However, users can use the ``config`` method or override the ``default_dump_all`` attribute to change this behavior.
Users can also override the ``pickle_backend`` attribute to choose a pickle backend. The supported values are "pickle" (the default and most common choice) and "dill" (which can dump more kinds of objects, such as functions; see `here <https://pypi.org/project/dill/>`_ for more information).
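The dump rule described above can be illustrated with a small standard-library sketch. This is not Qlib's implementation: the class ``SketchSerializable``, its ``dump_all`` flag and the ``dumps``/``loads`` helpers are hypothetical names chosen for this example, and only the built-in ``pickle`` module is used.

```python
import pickle


class SketchSerializable:
    """Hypothetical stand-in for a serializable base class (not Qlib's code)."""

    dump_all = False  # assumption: when True, attributes starting with "_" are kept too

    def state_to_dump(self) -> dict:
        # Keep only public attributes unless dump_all is enabled.
        return {k: v for k, v in vars(self).items() if self.dump_all or not k.startswith("_")}

    def dumps(self) -> bytes:
        # Pickle the filtered attribute dict rather than the whole instance.
        return pickle.dumps(self.state_to_dump())

    @classmethod
    def loads(cls, data: bytes):
        # Rebuild an instance without calling __init__ and restore the saved attributes.
        obj = cls.__new__(cls)
        vars(obj).update(pickle.loads(data))
        return obj


class Model(SketchSerializable):
    def __init__(self):
        self.params = {"lr": 0.1}  # public attribute: saved
        self._cache = [1, 2, 3]  # private attribute: dropped by default


restored = Model.loads(Model().dumps())
print(sorted(vars(restored)))  # ['params']
```

Setting ``dump_all = True`` on a subclass would keep ``_cache`` as well, which mirrors the idea of overriding the default dump behavior.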
Example
==========================

View File

@@ -0,0 +1,89 @@
.. _task_management:
=================================
Task Management
=================================
.. currentmodule:: qlib
Introduction
=============
The `Workflow <../component/introduction.html>`_ part introduces how to run a research workflow in a loosely-coupled way, but ``qrun`` can only execute one ``task`` at a time.
To automatically generate and execute different tasks, ``Task Management`` provides a complete process consisting of `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically over different periods, with different losses, or even with different models.
This whole process can be used in `Online Serving <../component/online.html>`_.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
Task Generating
===============
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
The specific task template can be viewed in
`Task Section <../component/workflow.html#task-section>`_.
Although the task template is fixed, users can customize their own ``TaskGen`` to generate different ``task`` instances from the task template.
Here is the base class of ``TaskGen``:
.. autoclass:: qlib.workflow.task.gen.TaskGen
:members:
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` whose datasets cover different date segments.
This class allows users to verify, in one experiment, how data from different periods affects the model. More information is available `here <../reference/api.html#TaskGen>`_.
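The idea of rolling a task template over different date segments can be sketched with the standard library alone. The sketch below is a simplified illustration, not the algorithm used by ``RollingGen``: the helpers ``weekdays`` and ``slide_segments`` are hypothetical, the trading calendar is assumed to be every Monday to Friday, and every segment is simply shifted forward by ``step`` calendar entries.

```python
from datetime import date, timedelta


def weekdays(start: date, end: date) -> list:
    """Hypothetical trading calendar: every Monday-Friday between start and end."""
    days, cur = [], start
    while cur <= end:
        if cur.weekday() < 5:
            days.append(cur)
        cur += timedelta(days=1)
    return days


def slide_segments(segments: dict, calendar: list, step: int) -> list:
    """Shift all segments forward by `step` calendar entries until one leaves the calendar."""
    index = {day: i for i, day in enumerate(calendar)}
    rolled, offset = [], 0
    while True:
        shifted = {}
        for name, (start, end) in segments.items():
            end_idx = index[end] + offset
            if end_idx >= len(calendar):
                return rolled  # the shifted segment no longer fits into the calendar
            shifted[name] = (calendar[index[start] + offset], calendar[end_idx])
        rolled.append(shifted)
        offset += step


cal = weekdays(date(2018, 1, 1), date(2018, 3, 30))
template = {
    "train": (cal[0], cal[29]),
    "valid": (cal[30], cal[39]),
    "test": (cal[40], cal[49]),
}
tasks = slide_segments(template, cal, step=10)
print(len(tasks))  # 2
```

Each element of ``tasks`` plays the role of the ``segments`` entry of one generated ``task``; a real generator would copy the rest of the task template alongside it.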
Task Storing
===============
To improve efficiency and make cluster operation possible, ``TaskManager`` stores all tasks in `MongoDB <https://www.mongodb.com/>`_.
``TaskManager`` can fetch unfinished tasks automatically and manage the lifecycle of a set of tasks with error handling.
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ before using this module.
To use ``TaskManager``, users need to provide the MongoDB URL and database name, either during `initialization <../start/initialization.html#Parameters>`_ or with a statement like this.
.. code-block:: python

    from qlib.config import C
    C["mongo"] = {
        "task_url": "mongodb://localhost:27017/",  # your MongoDB URL
        "task_db_name": "rolling_db",  # database name
    }

.. autoclass:: qlib.workflow.task.manage.TaskManager
:members:
More information about ``TaskManager`` can be found `here <../reference/api.html#TaskManager>`_.
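To make the storing and fetching behavior more concrete, here is a minimal in-memory sketch of a task pool. It is only an illustration under stated assumptions: ``InMemoryTaskPool`` and all of its methods are hypothetical, the status names other than *WAITING* are assumed, and a real pool would be a MongoDB collection rather than a Python dict.

```python
import json


class InMemoryTaskPool:
    """Hypothetical in-memory stand-in for a task pool (a real pool would live in MongoDB)."""

    WAITING, RUNNING, DONE = "waiting", "running", "done"  # assumed status names

    def __init__(self):
        self._tasks = {}  # key: canonical JSON of the task definition

    def insert(self, task_def: dict) -> bool:
        """Store a task once; return False when an identical task already exists."""
        key = json.dumps(task_def, sort_keys=True)
        if key in self._tasks:
            return False
        self._tasks[key] = {"def": task_def, "status": self.WAITING, "res": None}
        return True

    def fetch(self):
        """Return one waiting task and mark it as running, or None if nothing is left."""
        for task in self._tasks.values():
            if task["status"] == self.WAITING:
                task["status"] = self.RUNNING
                return task
        return None

    def commit(self, task: dict, result) -> None:
        """Attach the result and mark the task as done."""
        task["status"], task["res"] = self.DONE, result

    def give_back(self, task: dict) -> None:
        """Return a task to the waiting state, for example after a failure."""
        task["status"] = self.WAITING


pool = InMemoryTaskPool()
task_def = {"model": "LGBModel", "test": ["2017-01-01", "2017-12-31"]}
first = pool.insert(task_def)
second = pool.insert(dict(task_def))  # identical definition, so it is not stored again
task = pool.fetch()
pool.commit(task, result="trained")
```

Keying the store on the task definition is what lets several workers share one pool without training the same task twice.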
Task Training
===============
After generating and storing those ``task``, it's time to run the ``task`` that are in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in the task pool; however, users can also customize how tasks are executed.
An easy way to get the ``task_func`` is to use ``qlib.model.trainer.task_train`` directly.
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset* and *Record*.
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
.. autoclass:: qlib.model.trainer.Trainer
:members:
``Trainer`` will train a list of tasks and return a list of model recorders.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest one, and ``TrainerRM`` is based on ``TaskManager`` and helps manage the task lifecycle automatically.
If you do not want to use ``TaskManager`` to manage tasks, using ``TrainerR`` to train a list of tasks generated by ``TaskGen`` is enough.
`Here <../reference/api.html#Trainer>`_ are the details about the different ``Trainer`` classes.
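The way a worker applies a ``task_func`` to every waiting task can be sketched as a plain loop. This is a hedged illustration rather than the logic of ``run_task``: ``run_waiting_tasks`` and ``fake_train`` are hypothetical, tasks are plain dicts, and the ``running`` and ``done`` status names are assumptions.

```python
def run_waiting_tasks(tasks: list, task_func) -> int:
    """Run `task_func` on every task whose status is "waiting"; return how many finished."""
    finished = 0
    for task in tasks:
        if task["status"] != "waiting":
            continue
        task["status"] = "running"
        try:
            task["res"] = task_func(task["def"])
        except Exception:
            # Hand the task back so that it can be retried later.
            task["status"] = "waiting"
            continue
        task["status"] = "done"
        finished += 1
    return finished


def fake_train(task_def: dict) -> str:
    """Hypothetical task function standing in for a real training routine."""
    if task_def["model"] == "broken":
        raise ValueError("cannot train this model")
    return f"recorder-for-{task_def['model']}"


tasks = [
    {"def": {"model": name}, "status": "waiting", "res": None}
    for name in ("LGBModel", "broken", "XGBModel")
]
done = run_waiting_tasks(tasks, fake_train)
print(done, [t["status"] for t in tasks])  # 2 ['done', 'waiting', 'done']
```

Returning a failed task to the waiting state keeps the pool consistent, so a later run can pick it up again once the problem is fixed.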
Task Collecting
===============
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
`Collector <../reference/api.html#Collector>`_ can collect objects from anywhere and process them, for example by merging, grouping or averaging. It works in two steps: ``collect`` (collect anything into a dict) and ``process_collect`` (process the collected dict).
`Group <../reference/api.html#Group>`_ also has two steps: ``group`` (group a set of objects based on `group_func` and turn them into a dict) and ``reduce`` (turn a dict into an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
So the hierarchy is as follows: the second step of ``Collector`` corresponds to ``Group``, and the second step of ``Group`` corresponds to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
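The ``group`` and ``reduce`` steps shown in the examples above can be reproduced with a short standard-library sketch. It is not the API of ``Group`` or ``Ensemble``: the functions ``group``, ``reduce_groups`` and ``average_ensemble`` are hypothetical, and plain floats stand in for the collected objects.

```python
from collections import defaultdict
from statistics import mean


def group(collected: dict, group_func) -> dict:
    """Split each key into (outer, inner) with `group_func` and nest the values."""
    grouped = defaultdict(dict)
    for key, value in collected.items():
        outer, inner = group_func(key)
        grouped[outer][inner] = value
    return dict(grouped)


def reduce_groups(grouped: dict, ensemble_func) -> dict:
    """Merge every inner dict into a single object with `ensemble_func`."""
    return {outer: ensemble_func(inner) for outer, inner in grouped.items()}


def average_ensemble(objects: dict) -> float:
    """A very simple ensemble rule: average the grouped values."""
    return mean(objects.values())


collected = {("A", "B", "C1"): 1.0, ("A", "B", "C2"): 3.0}
grouped = group(collected, lambda key: (key[:-1], key[-1]))
reduced = reduce_groups(grouped, average_ensemble)
print(grouped)  # {('A', 'B'): {'C1': 1.0, 'C2': 3.0}}
print(reduced)  # {('A', 'B'): 2.0}
```

Keeping the grouping rule and the ensemble rule as separate functions is what makes the two steps independently replaceable.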

46
docs/component/online.rst Normal file
View File

@@ -0,0 +1,46 @@
.. _online:
=================================
Online Serving
=================================
.. currentmodule:: qlib
Introduction
=============
.. image:: ../_static/img/online_serving.png
:align: center
In addition to backtesting, one way to test whether a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
``Online Serving`` is a set of modules for running online models with the latest data,
which includes `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_ and `Updater <#Updater>`_.
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
If you have many models or ``task`` to manage, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
==============
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
===============
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
=============
.. automodule:: qlib.workflow.online.utils
:members:
Updater
=============
.. automodule:: qlib.workflow.online.update
:members:

View File

@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>
.. toctree::
:maxdepth: 3
@@ -50,6 +51,7 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
Serialization <advanced/serial.rst>
Task Management <advanced/task_management.rst>
.. toctree::
:maxdepth: 3

View File

@@ -154,6 +154,70 @@ Record Template
.. automodule:: qlib.workflow.record_temp
:members:
Task Management
====================
TaskGen
--------------------
.. automodule:: qlib.workflow.task.gen
:members:
TaskManager
--------------------
.. automodule:: qlib.workflow.task.manage
:members:
Trainer
--------------------
.. automodule:: qlib.model.trainer
:members:
Collector
--------------------
.. automodule:: qlib.workflow.task.collect
:members:
Group
--------------------
.. automodule:: qlib.model.ens.group
:members:
Ensemble
--------------------
.. automodule:: qlib.model.ens.ensemble
:members:
Utils
--------------------
.. automodule:: qlib.workflow.task.utils
:members:
Online Serving
====================
Online Manager
--------------------
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
--------------------
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
--------------------
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
:members:
Utils
====================
@@ -162,4 +226,7 @@ Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
:members:
:members:

View File

@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
- `mongo`
Type: dict, optional parameter. The settings of `MongoDB <https://www.mongodb.com/>`_, which is used by some features such as `Task Management <../advanced/task_management.html>`_ to provide high performance and clustered processing.
Users need to finish the `installation <https://www.mongodb.com/try/download/community>`_ first and run MongoDB at a fixed URL.
.. code-block:: Python

    # For example, you can initialize qlib as below
    qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
        "task_url": "mongodb://localhost:27017/",  # your MongoDB URL
        "task_db_name": "rolling_db",  # the database name of Task Management
    })

View File

@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
"""
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class RollingTaskExample:
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        experiment_name="rolling_exp",
        task_pool="rolling_task",
        task_config=[task_xgboost_config, task_lgb_config],
        rolling_step=550,
        rolling_type=RollingGen.ROLL_SD,
    ):
        # TaskManager config
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.experiment_name = experiment_name
        self.task_pool = task_pool
        self.task_config = task_config
        self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        print("========== reset ==========")
        TaskManager(task_pool=self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.experiment_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def task_generating(self):
        print("========== task_generating ==========")
        tasks = task_generator(
            tasks=self.task_config,
            generators=self.rolling_gen,  # generate different date segments
        )
        pprint(tasks)
        return tasks

    def task_training(self, tasks):
        print("========== task_training ==========")
        trainer = TrainerRM(self.experiment_name, self.task_pool)
        trainer.train(tasks)

    def task_collecting(self):
        print("========== task_collecting ==========")

        def rec_key(recorder):
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        def my_filter(recorder):
            # only choose the results of "LGBModel"
            model_key, rolling_key = rec_key(recorder)
            if model_key == "LGBModel":
                return True
            return False

        collector = RecorderCollector(
            experiment=self.experiment_name,
            process_list=RollingGroup(),
            rec_key_func=rec_key,
            rec_filter_func=my_filter,
        )
        print(collector())

    def main(self):
        self.reset()
        tasks = self.task_generating()
        self.task_training(tasks)
        self.task_collecting()


if __name__ == "__main__":
    ## to see the whole process with your own parameters, use the command below
    # python task_manager_rolling.py main --experiment_name="your_exp_name"
    fire.Fire(RollingTaskExample)

View File

@@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how to simulate the OnlineManager based on rolling tasks.
"""
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
data_handler_config = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class OnlineSimulationExample:
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=[task_xgboost_config, task_lgb_config],
    ):
        """
        Init OnlineSimulationExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict]): a set of the task config waiting for rolling and training
        """
        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.rolling_gen = RollingGen(
            step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
        )  # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
        self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)  # Also can be TrainerR, TrainerRM, DelayTrainerR
        self.rolling_online_manager = OnlineManager(
            RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        TaskManager(self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.exp_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    # Run this to run the whole workflow automatically
    def main(self):
        print("========== reset ==========")
        self.reset()
        print("========== simulate ==========")
        self.rolling_online_manager.simulate(end_time=self.end_time)
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())


if __name__ == "__main__":
    ## to run the whole workflow automatically with your own parameters, use the command below
    # python online_management_simulate.py main --exp_name="your_exp_name" --rolling_step=60
    fire.Fire(OnlineSimulationExample)

View File

@@ -0,0 +1,181 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineManager works with rolling tasks.
There are four parts: first train, routine 1, add strategy and routine 2.
Firstly, the OnlineManager will finish the first training and set the trained models as `online` models.
Next, the OnlineManager will finish a routine process, which consists of: update online predictions -> prepare tasks -> prepare new models -> prepare signals.
Then, we will add some new strategies to the OnlineManager. This will finish the first training of the new strategies.
Finally, the OnlineManager will finish the second routine and update all strategies.
"""
import os
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
data_handler_config = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class RollingOnlineExample:
    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        rolling_step=550,
        tasks=[task_xgboost_config],
        add_tasks=[task_lgb_config],
    ):
        mongo_conf = {
            "task_url": task_url,  # your MongoDB url
            "task_db_name": task_db_name,  # database name
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.tasks = tasks
        self.add_tasks = add_tasks
        self.rolling_step = rolling_step
        strategies = []
        for task in tasks:
            name_id = task["model"]["class"]  # NOTE: Assumption: The model class can specify only one strategy
            strategies.append(
                RollingStrategy(
                    name_id,
                    task,
                    RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
                )
            )
        self.rolling_online_manager = OnlineManager(strategies)

    _ROLLING_MANAGER_PATH = (
        ".RollingOnlineExample"  # the OnlineManager is dumped to this file so that it can be loaded when calling routine.
    )

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        for task in self.tasks + self.add_tasks:
            name_id = task["model"]["class"]
            exp = R.get_exp(experiment_name=name_id)
            for rid in exp.list_recorders():
                exp.delete_recorder(rid)

        if os.path.exists(self._ROLLING_MANAGER_PATH):
            os.remove(self._ROLLING_MANAGER_PATH)

    def first_run(self):
        print("========== reset ==========")
        self.reset()
        print("========== first_run ==========")
        self.rolling_online_manager.first_train()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def routine(self):
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== routine ==========")
        self.rolling_online_manager.routine()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def add_strategy(self):
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== add strategy ==========")
        strategies = []
        for task in self.add_tasks:
            name_id = task["model"]["class"]  # NOTE: Assumption: The model class can specify only one strategy
            strategies.append(
                RollingStrategy(
                    name_id,
                    task,
                    RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
                )
            )
        self.rolling_online_manager.add_strategy(strategies=strategies)
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def main(self):
        self.first_run()
        self.routine()
        self.add_strategy()
        self.routine()


if __name__ == "__main__":
    ####### to train the first version's models, use the command below
    # python rolling_online_management.py first_run
    ####### to update the models and predictions after the trading time, use the command below
    # python rolling_online_management.py routine
    ####### to define your own parameters, use `--`
    # python rolling_online_management.py first_run --rolling_step=40
    fire.Fire(RollingOnlineExample)

View File

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineTool works when we need to update predictions.
There are two parts: first_train and update_online_pred.
Firstly, we will finish the training and set the trained models as the `online` models.
Next, we will finish updating the online predictions.
"""
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.utils import OnlineToolR
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
}
class UpdatePredExample:
    def __init__(
        self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
    ):
        qlib.init(provider_uri=provider_uri, region=region)
        self.experiment_name = experiment_name
        self.online_tool = OnlineToolR(self.experiment_name)
        self.task_config = task_config

    def first_train(self):
        rec = task_train(self.task_config, experiment_name=self.experiment_name)
        self.online_tool.reset_online_tag(rec)  # set to online model

    def update_online_pred(self):
        self.online_tool.update_online_pred()

    def main(self):
        self.first_train()
        self.update_online_pred()


if __name__ == "__main__":
    ## to train a model and set it to online model, use the command below
    # python update_online_pred.py first_train
    ## to update online predictions once a day, use the command below
    # python update_online_pred.py update_online_pred
    ## to see the whole process with your own parameters, use the command below
    # python update_online_pred.py main --experiment_name="your_exp_name"
    fire.Fire(UpdatePredExample)

View File

@@ -3,6 +3,7 @@
__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
@@ -10,12 +11,13 @@ import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
# init qlib
def init(default_conf="client", **kwargs):
from .config import C
from .log import get_module_logger
from .data.cache import H
H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
def _mount_nfs_uri(C):
from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs):
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
    """
    If users are building a project that follows the pattern below:
    - Qlib is a sub folder in the project path
    - There is a file named `config.yaml` in the project path.

    For example:
        If your project file system structure follows such a pattern

            <project_path>/
              - config.yaml
              - ...some folders...
                - qlib/

        This function will return <project_path>

        NOTE: link is not supported here.

    This method is often used when
    - users want to use a relative config path instead of hard-coding the qlib config path in code

    Raises
    ------
    FileNotFoundError:
        If project path is not found
    """
    if cur_path is None:
        cur_path = Path(__file__).absolute().resolve()
    # Walk up the directory tree until a folder containing `config_name` is found
    while True:
        if (cur_path / config_name).exists():
            return cur_path
        if cur_path == cur_path.parent:
            raise FileNotFoundError("We can't find the project path")
        cur_path = cur_path.parent


def auto_init(**kwargs):
    """
    This function will init qlib automatically with the following priority
    - Find the project configuration and init qlib
        - The parsing process will be affected by the `conf_type` of the configuration file
    - Init qlib with default config
    """
    try:
        pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
    except FileNotFoundError:
        init(**kwargs)
    else:
        conf_pp = pp / "config.yaml"
        with conf_pp.open() as f:
            conf = yaml.safe_load(f)
        conf_type = conf.get("conf_type", "origin")
        if conf_type == "origin":
            # The type of config is just like original qlib config
            init_from_yaml_conf(conf_pp, **kwargs)
        elif conf_type == "ref":
            # This config type will be more convenient in the following scenario
            # - There is a shared configure file and you don't want to edit it inplace.
            # - The shared configure may be updated later and you don't want to copy it.
            # - You have some customized config.
            qlib_conf_path = conf["qlib_cfg"]
            # Fall back to an empty dict so that `**` unpacking works when the key is absent
            qlib_conf_update = conf.get("qlib_cfg_update", {})
            init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
        logger = get_module_logger("Initialization")
        logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -33,6 +33,9 @@ class Config:
raise AttributeError(f"No such {attr} in self._config")
def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -131,7 +134,7 @@ _default_config = {
},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
},
# Defatult config for experiment manager
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -140,6 +143,11 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Default config for MongoDB
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
},
}
MODE_CONF = {
@@ -310,8 +318,22 @@ class QlibConfig(Config):
# clean up experiment when python program ends
experiment_exit_handler()
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()
self._registered = True
def reset_qlib_version(self):
import qlib
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# NOTE: inside a class body, `qlib.__version__bak` would be name-mangled to `qlib._QlibConfig__version__bak`,
# so `getattr` with the literal attribute name is used instead of direct attribute access.
@property
def registered(self):
return self._registered

View File

@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missing.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)

View File

@@ -27,7 +27,7 @@ class Dataset(Serializable):
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
The data could specify the info to caculate the essential data for preparation
The data could specify the info to calculate the essential data for preparation
"""
self.setup_data(**kwargs)
super().__init__()
@@ -92,7 +92,7 @@ class DatasetH(Dataset):
handler : Union[dict, DataHandler]
handler could be:
- insntance of `DataHandler`
- instance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
@@ -112,8 +112,9 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = {}
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
@@ -123,7 +124,7 @@ class DatasetH(Dataset):
Parameters
----------
handler_kwargs : dict
Config of DataHanlder, which could include the following arguments:
Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
@@ -147,11 +148,11 @@ class DatasetH(Dataset):
Parameters
----------
handler_kwargs : dict
init arguments of DataHanlder, which could include the following arguments:
init arguments of DataHandler, which could include the following arguments:
- init_type : Init Type of Handler
- enable_cache : wheter to enable cache
- enable_cache : whether to enable cache
"""
super().setup_data(**kwargs)
@@ -171,7 +172,10 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self.handler.fetch(slc, **kwargs)
if hasattr(self, "fetch_kwargs"):
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
else:
return self.handler.fetch(slc, **kwargs)
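The `fetch_kwargs` merge in the hunk above relies on Python's double-star unpacking. The following is a minimal sketch of that pattern only: `fetch` and `ToyDataset` are hypothetical stand-ins, not qlib's `DataHandler.fetch` or `DatasetH`. It shows that stored defaults are combined with call-time arguments, and that a key supplied in both places raises `TypeError` rather than being silently overridden.

```python
def fetch(selector, **kwargs):
    """Hypothetical stand-in for a handler's fetch; it just echoes its arguments."""
    return selector, kwargs


class ToyDataset:
    """Illustrative only; not qlib's DatasetH."""

    def __init__(self):
        self.fetch_kwargs = {}

    def _prepare_seg(self, slc, **kwargs):
        # Same merge pattern as in the diff: call-time kwargs plus stored defaults.
        return fetch(slc, **kwargs, **self.fetch_kwargs)


ds = ToyDataset()
ds.fetch_kwargs = {"col_set": "feature"}
merged = ds._prepare_seg(slice(0, 2), data_key="infer")

# Passing a key that is also stored in fetch_kwargs is an error, not an override.
try:
    ds._prepare_seg(slice(0, 2), col_set="label")
    duplicate_error = None
except TypeError as err:
    duplicate_error = err
```

If overriding stored defaults at call time is desired, merging the two dicts before the call (for example `{**self.fetch_kwargs, **kwargs}`) would be one option; the diff does not do this.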
def prepare(
self,
@@ -199,6 +203,12 @@ class DatasetH(Dataset):
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
kwargs :
The parameters that kwargs may contain:
flt_col : str
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
This parameter is only supported when it is an instance of TSDatasetH.
Returns
-------
Union[List[pd.DataFrame], pd.DataFrame]:
@@ -231,7 +241,7 @@ class TSDataSampler:
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
@@ -243,7 +253,9 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -265,6 +277,11 @@ class TSDataSampler:
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
None:
keep all data
"""
self.start = start
self.end = end
@@ -272,23 +289,51 @@ class TSDataSampler:
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE: append last line with full NaN for better performance in `__getitem__`
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
kwargs = {"object": self.data}
if dtype is not None:
kwargs["dtype"] = dtype
self.data_arr = np.array(**kwargs) # Indexing a numpy.array is much faster than indexing DataFrame.values!
# NOTE:
# - append a last line of all NaN for better performance in `__getitem__`
# - Keeping the same dtype results in better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
)
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_df, self.idx_map = self.build_index(self.data)
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
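The NaN-row padding and the `dtype` handling in this constructor can be illustrated in isolation. This is a sketch with toy values and hypothetical variable names (`values`, `window_indices`), not the sampler's real data layout: it shows that building the sentinel row with the array's own dtype keeps the padded array from being upcast, and that index `-1` then always resolves to the all-NaN row.

```python
import numpy as np

values = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)

# Append one all-NaN row using the array's own dtype so the result is not upcast.
nan_row = np.full((1, values.shape[1]), np.nan, dtype=values.dtype)
padded = np.append(values, nan_row, axis=0)

# Without an explicit dtype the NaN row is float64 and the whole array is upcast.
upcast = np.append(values, np.full((1, values.shape[1]), np.nan), axis=0)

nan_idx = -1  # the last row is the all-NaN sentinel

# A window whose first step has no data: point that step at the sentinel row.
window_indices = np.array([nan_idx, 0, 1])
window = padded[window_indices]
```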
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
new_idx_map = {}
for i, exist in enumerate(flt_data):
if exist:
new_idx_map[idx] = idx_map[i]
idx += 1
return new_idx_map
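The renumbering done by `flt_idx_map` is easy to check on a toy input. The sketch below copies the logic as a standalone function; the contents of `idx_map` here are made-up `(row, column)` tuples for illustration, since the real values come from `build_index`, which is not shown in this diff.

```python
def flt_idx_map(flt_data, idx_map):
    """Re-number idx_map so that only rows flagged True remain, in order."""
    new_idx_map = {}
    idx = 0
    for i, keep in enumerate(flt_data):
        if keep:
            new_idx_map[idx] = idx_map[i]
            idx += 1
    return new_idx_map


# Toy mapping from a flat row number to a position (values are illustrative).
idx_map = {0: (0, 0), 1: (0, 1), 2: (1, 0), 3: (1, 1)}
flt_data = [True, False, False, True]

filtered = flt_idx_map(flt_data, idx_map)
```

After filtering, the surviving rows are addressed by a dense range `0..n-1`, which is why the constructor also rebuilds `data_index` and recomputes `start_idx` and `end_idx` from it.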
def get_index(self):
"""
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
return self.data.index[self.start_idx : self.end_idx]
return self.data_index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -432,7 +477,7 @@ class TSDatasetH(DatasetH):
(T)ime-(S)eries Dataset (H)andler
Covnert the tabular data to Time-Series data
Convert the tabular data to Time-Series data
Requirements analysis
@@ -461,7 +506,7 @@ class TSDatasetH(DatasetH):
cal = sorted(cal)
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
@@ -470,6 +515,25 @@ class TSDatasetH(DatasetH):
# TSDatasetH will retrieve extra data before `start` so that the time-series is complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
`_prepare_raw_seg` is split out to leave a hook for preprocessing the data before the `TSDataSampler` is created
"""
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve extra data before `start` so that the time-series is complete
data = self._prepare_raw_seg(slc, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds


@@ -7,7 +7,7 @@ import bisect
import logging
import warnings
from inspect import getfullargspec
from typing import Union, Tuple, List, Iterator, Optional
from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -36,7 +36,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
Any order of the index level can be suported (The order will be implied in the data).
Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -51,6 +51,9 @@ class DataHandler(Serializable):
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Tips for improving the performance of datahandler
- Fetching data with `col_set=CS_RAW` will return the raw data and may prevent pandas from copying the data when calling `loc`
"""
def __init__(
@@ -74,7 +77,7 @@ class DataHandler(Serializable):
data_loader : Union[dict, str, DataLoader]
data loader to load the data.
init_data :
intialize the original data in the constructor.
initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
"""
@@ -125,7 +128,7 @@ class DataHandler(Serializable):
def setup_data(self, enable_cache: bool = False):
"""
Set Up the data in case of running intialization for multiple time
Set up the data in case the initialization is run multiple times
It is responsible for maintaining following variable
1) self._data
@@ -163,6 +166,7 @@ class DataHandler(Serializable):
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -185,6 +189,14 @@ class DataHandler(Serializable):
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
proc_func: Callable
- Provides a hook for processing the data before it is fetched
- An example that explains why the hook is needed:
- A Dataset has learned some processors to process data, and they are related to the data segmentation
- It applies them every time data is prepared.
- A learned processor requires the dataframe to keep the same format when fitting and when applying
- However, the data format changes according to the parameters.
- So the processors should be applied to the underlying data.
squeeze : bool
whether squeeze columns and index
@@ -193,8 +205,15 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(self._data, col_set)
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -261,6 +280,10 @@ class DataHandler(Serializable):
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
Tips for improving the performance of data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the raw data in place;
"""
# data key
@@ -430,7 +453,7 @@ class DataHandlerLP(DataHandler):
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
"""
Set up the data in case of running intialization for multiple time
Set up the data in case the initialization is run multiple times
Parameters
----------
@@ -474,6 +497,7 @@ class DataHandlerLP(DataHandler):
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -488,12 +512,18 @@ class DataHandlerLP(DataHandler):
select a set of meaningful columns.(e.g. features, columns).
data_key : str
the data to fetch: DK_*.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case `proc_func` changes the data in place
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
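The order of operations around `proc_func` (slice by index, copy, process, then select columns and slice again) can be sketched with plain pandas. This is an illustration under assumptions: `fetch` and `demean_inplace` are hypothetical helpers, and `.loc` stands in for qlib's `fetch_df_by_index` and `_fetch_df_by_col`. It shows why the copy matters when the hook mutates its input.

```python
import pandas as pd


def fetch(data, selector, col_set, proc_func=None):
    """Toy version of the fetch flow: optional hook first, then columns, then rows."""
    df = data
    if proc_func is not None:
        # Slice first, then copy, so an in-place proc_func cannot touch `data`.
        df = proc_func(df.loc[selector].copy())
    df = df[col_set]
    return df.loc[selector]


def demean_inplace(df):
    """Example hook that modifies its argument in place."""
    df["feature"] -= df["feature"].mean()
    return df


index = pd.date_range("2021-01-01", periods=4, name="datetime")
data = pd.DataFrame(
    {"feature": [1.0, 2.0, 3.0, 4.0], "label": [0.0, 1.0, 0.0, 1.0]}, index=index
)
selector = slice(pd.Timestamp("2021-01-01"), pd.Timestamp("2021-01-02"))

result = fetch(data, selector, ["feature"], proc_func=demean_inplace)
```

The hook only sees the two selected rows, so the mean it removes is the mean of that segment, and the stored data keeps its original values.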


@@ -13,6 +13,7 @@ from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.log import get_module_logger
class DataLoader(abc.ABC):
@@ -224,6 +225,10 @@ class DataLoaderDH(DataLoader):
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
TODO: What makes this module not that easy to use.
- For online scenario
- The underlying data handler should be configured. But data loader doesn't provide such an interface & hook.
"""
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
@@ -265,7 +270,7 @@ class DataLoaderDH(DataLoader):
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
LOG.warning(f"instruments[{instruments}] is ignored")
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
if self.is_group:
df = pd.concat(


@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
import numpy as np
import pandas as pd
import copy
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
EPS = 1e-12
def get_group_columns(df: pd.DataFrame, group: str):
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
"""
get a group of columns from multi-index columns DataFrame


115
qlib/model/ens/ensemble.py Normal file

@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Ensemble module can merge the objects in an Ensemble. For example, if there are many submodel predictions, we may need to merge them into an ensemble prediction.
"""
from typing import Union
import pandas as pd
from qlib.utils import FLATTEN_TUPLE, flatten_dict
class Ensemble:
"""Merge the ensemble_dict into an ensemble object.
For example: {Rollinga_b: object, Rollingb_c: object} -> object
When calling this class:
Args:
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
Returns:
object: the ensemble object
"""
def __call__(self, ensemble_dict: dict, *args, **kwargs):
raise NotImplementedError(f"Please implement the `__call__` method.")
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
If the dict has more than one key or is empty, it is returned unchanged.
You can even run this recursively to make the dict more readable.
NOTE: Default runs recursively.
When calling this class:
Args:
ensemble_dict (dict): the dict. The key of the dict will be ignored.
Returns:
dict: the readable dict.
"""
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
if not isinstance(ensemble_dict, dict):
return ensemble_dict
if recursion:
tmp_dict = {}
for k, v in ensemble_dict.items():
tmp_dict[k] = self(v, recursion)
ensemble_dict = tmp_dict
keys = list(ensemble_dict.keys())
if len(keys) == 1:
ensemble_dict = ensemble_dict[keys[0]]
return ensemble_dict
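The collapsing behaviour of `SingleKeyEnsemble` is easiest to see on small nested dicts. The sketch below restates the same logic as a standalone function (`collapse_single_key` is a hypothetical name chosen for this example) and runs it both recursively and non-recursively.

```python
def collapse_single_key(obj, recursion=True):
    """Unwrap dicts that hold exactly one entry; leave everything else unchanged."""
    if not isinstance(obj, dict):
        return obj
    if recursion:
        # Collapse the children first, then look at this level.
        obj = {k: collapse_single_key(v, recursion) for k, v in obj.items()}
    if len(obj) == 1:
        return next(iter(obj.values()))
    return obj


nested = {"exp": {"model_a": {"pred": 1}, "model_b": {"pred": 2}}}
flat = collapse_single_key(nested)

# Without recursion only the outermost single-key level is unwrapped.
shallow = collapse_single_key({"exp": {"pred": {"x": 1, "y": 2}}}, recursion=False)
```

Note that the unwrapped key is discarded, so the result alone no longer tells which experiment or model a value came from.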
class RollingEnsemble(Ensemble):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.DataFrame: the complete result of rolling.
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
# If there are duplicated predictions, use the latest prediction
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact
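How `RollingEnsemble` resolves overlapping windows can be checked with two small frames. This sketch repeats the same four steps in a standalone function; `make_pred`, `merge_rolling`, the instrument code, and the scores are all made up for illustration. The dict is passed in the "wrong" order on purpose to show that the sort by earliest `datetime` decides which window counts as the latest.

```python
import pandas as pd


def make_pred(dates, scores):
    """Build a toy prediction frame indexed by (datetime, instrument)."""
    index = pd.MultiIndex.from_product(
        [pd.to_datetime(dates), ["SH600000"]], names=["datetime", "instrument"]
    )
    return pd.DataFrame({"score": scores}, index=index)


def merge_rolling(ensemble_dict):
    """Concatenate rolling windows; on overlap keep the row from the later window."""
    frames = sorted(
        ensemble_dict.values(),
        key=lambda df: df.index.get_level_values("datetime").min(),
    )
    merged = pd.concat(frames)
    merged = merged[~merged.index.duplicated(keep="last")]
    return merged.sort_index()


early = make_pred(["2021-01-01", "2021-01-02"], [0.1, 0.2])
late = make_pred(["2021-01-02", "2021-01-03"], [0.9, 0.3])

merged = merge_rolling({"late": late, "early": early})
```

For the overlapping date 2021-01-02 the score from the later window (0.9) is kept.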
class AverageEnsemble(Ensemble):
"""
Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, it is flattened first.
When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.DataFrame: the complete result of averaging and standardizing.
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())
results = pd.concat(values, axis=1)
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
results = results.mean(axis=1)
results = results.sort_index()
return results
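The "standardize per day, then average across models" step can be sketched on toy data. This is not the code from the diff: it swaps `groupby(...).apply` for `groupby(level="datetime").transform`, takes a dict of Series with distinct names instead of same-named DataFrame columns, and uses a hypothetical function name, all to keep the example simple and predictable. The scores are made up.

```python
import pandas as pd

index = pd.MultiIndex.from_product(
    [pd.to_datetime(["2021-01-01", "2021-01-02"]), ["A", "B", "C"]],
    names=["datetime", "instrument"],
)
predictions = {
    "model_a": pd.Series([1.0, 2.0, 3.0, 10.0, 20.0, 30.0], index=index),
    "model_b": pd.Series([2.0, 4.0, 6.0, 3.0, 2.0, 1.0], index=index),
}


def average_standardized(pred_dict):
    """Z-score each model's scores within each day, then average the models."""
    results = pd.concat(pred_dict, axis=1)  # one column per model
    results = results.groupby(level="datetime").transform(
        lambda x: (x - x.mean()) / x.std()
    )
    return results.mean(axis=1).sort_index()


avg = average_standardized(predictions)
```

On the first day both models rank the instruments the same way, so the average keeps the ranking; on the second day they disagree exactly, so the averaged scores cancel to zero.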

113
qlib/model/ens/group.py Normal file

@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Group can group a set of objects based on `group_func` and change them to a dict.
After group, we provide a method to reduce them.
For example:
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
"""
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from joblib import Parallel, delayed
class Group:
"""Group the objects based on dict"""
def __init__(self, group_func=None, ens: Ensemble = None):
"""
Init Group.
Args:
group_func (Callable, optional): Given a dict, return the group key and one of the group elements.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
Defaults to None.
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
"""
self._group_func = group_func
self._ens_func = ens
def group(self, *args, **kwargs) -> dict:
"""
Group a set of objects and change them to a dict.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
Returns:
dict: grouped dict
"""
if isinstance(getattr(self, "_group_func", None), Callable):
return self._group_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `group_func`.")
def reduce(self, *args, **kwargs) -> dict:
"""
Reduce grouped dict.
For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
Returns:
dict: reduced dict
"""
if isinstance(getattr(self, "_ens_func", None), Callable):
return self._ens_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `_ens_func`.")
def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
"""
Group the ungrouped_dict into different groups.
Args:
ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
n_jobs (int): the number of parallel jobs passed to joblib `Parallel`.
verbose (int): the verbosity level passed to joblib `Parallel`.
Returns:
dict: grouped_dict like {G1: object, G2: object}
"""
# NOTE: The multiprocessing will raise error if you use `Serializable`
# Because the `Serializable` will affect the behaviors of pickle
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
key_l = []
job_l = []
for key, value in grouped_dict.items():
key_l.append(key)
job_l.append(delayed(Group.reduce)(self, value))
return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
class RollingGroup(Group):
"""Group the rolling dict"""
def group(self, rolling_dict: dict) -> dict:
"""Given a rolling dict like {(A,B,R): things}, return the grouped dict like {(A,B): {R:things}}
NOTE: It is assumed that the rolling key is at the end of the key tuple, because the rolling results always need to be ensembled first.
Args:
rolling_dict (dict): a rolling dict. Entries whose key is not a tuple are skipped.
Returns:
dict: grouped dict
"""
grouped_dict = {}
for key, values in rolling_dict.items():
if isinstance(key, tuple):
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
return grouped_dict
def __init__(self):
super().__init__(ens=RollingEnsemble())
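The group-then-reduce flow of `Group` and `RollingGroup` can be tried without joblib or pandas. In this sketch the function names, the keys, and the averaging reducer are hypothetical; the grouping loop follows `RollingGroup.group`, and the reduce step is applied serially instead of through `Parallel`.

```python
def group_by_prefix(rolling_dict):
    """Group {(A, B, R): value} into {(A, B): {R: value}}; non-tuple keys are skipped."""
    grouped = {}
    for key, value in rolling_dict.items():
        if isinstance(key, tuple):
            grouped.setdefault(key[:-1], {})[key[-1]] = value
    return grouped


def group_then_reduce(ungrouped, reduce_func):
    """Serial equivalent of calling a Group: group first, then reduce each group."""
    return {
        key: reduce_func(members)
        for key, members in group_by_prefix(ungrouped).items()
    }


results = {
    ("exp", "lgb", "2020"): 1.0,
    ("exp", "lgb", "2021"): 3.0,
    ("exp", "mlp", "2020"): 5.0,
    "not_a_tuple": 7.0,
}

grouped = group_by_prefix(results)
reduced = group_then_reduce(
    results, lambda members: sum(members.values()) / len(members)
)
```

The entry keyed by a plain string does not appear in either output, which matches the tuple check in `RollingGroup.group`.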


@@ -1,27 +0,0 @@
import abc
import typing
class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
"""
generate
Parameters
----------
args, kwargs:
The info for generating tasks
Example 1):
input: a specific task template
output: rolling version of the tasks
Example 2):
input: a specific task template
output: a set of tasks with different losses
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass


@@ -1,42 +1,446 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config, flatten_dict
"""
The Trainer will train a list of tasks and return a list of model recorders.
There are two steps in each Trainer: ``train`` (make model recorder) and ``end_train`` (modify model recorder).
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step only saves some necessary info to the model recorders, and the second step, which runs at the end, can do concurrent and time-consuming operations such as model fitting.
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
"""
import socket
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
def task_train(task_config: dict, experiment_name):
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
task based training
Begin task training to start a recorder and save the task config.
Args:
task_config (dict): the config of a task
experiment_name (str): the name of experiment
recorder_name (str): the given name will be the recorder name. None for using rid.
Returns:
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
Finish task training with real model fitting and saving.
Args:
rec (Recorder): the recorder will be resumed
experiment_name (str): the name of experiment
Returns:
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# model training
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:
rconf = {"recorder": rec}
r = cls(**kwargs, **rconf)
r.generate()
return rec
def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
Task based training, will be divided into two steps.
Parameters
----------
task_config : dict
A dict describes a task setting.
The config of a task.
experiment_name: str
The name of experiment
Returns
----------
Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
class Trainer:
"""
The trainer can train a list of models.
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
"""
# model initiaiton
model = init_instance_by_config(task_config["model"])
dataset = init_instance_by_config(task_config["dataset"])
def __init__(self):
self.delay = False
# start exp
with R.start(experiment_name=experiment_name):
# train model
R.log_params(**flatten_dict(task_config))
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
def train(self, tasks: list, *args, **kwargs) -> list:
"""
Given a list of task definitions, begin training, and return the models.
# generate records: prediction, backtest, and analysis
for record in task_config["record"]:
if record["class"] == SignalRecord.__name__:
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
For Trainer, it finishes real training in this method.
For DelayTrainer, it only does some preparation in this method.
Args:
tasks: a list of tasks
Returns:
list: a list of models
"""
raise NotImplementedError(f"Please implement the `train` method.")
def end_train(self, models: list, *args, **kwargs) -> list:
"""
Given a list of models, finish any remaining work at the end of training if needed.
The models may be Recorder, txt file, database, and so on.
For Trainer, it does some finishing touches in this method.
For DelayTrainer, it finishes real training in this method.
Args:
models: a list of models
Returns:
list: a list of models
"""
# do nothing if you finished all work in `train` method
return models
def is_delay(self) -> bool:
"""
Whether the Trainer delays the real training until `end_train`.
Returns:
bool: if DelayTrainer
"""
return self.delay
class TrainerR(Trainer):
"""
Trainer based on (R)ecorder.
It will train a list of tasks and return a list of model recorders in a linear way.
Assumption: models were defined by `task` and the results will be saved to `Recorder`.
"""
# These tags help you distinguish whether the Recorder has finished training
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
"""
Init TrainerR.
Args:
experiment_name (str, optional): the default name of experiment.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.train_func = train_func
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of `task`s, return a list of trained Recorders. The order of the tasks is preserved.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
recs = []
for task in tasks:
rec = train_func(task, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
class DelayTrainerR(TrainerR):
"""
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
"""
Init DelayTrainerR.
Args:
experiment_name (str): the default name of experiment.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, train_func)
self.end_train_func = end_train_func
self.delay = True
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorders, return a list of trained Recorders.
This class will finish real data loading and model fitting.
Args:
recs (list): a list of Recorder, the tasks have been saved to them
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
for rec in recs:
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
continue
end_train_func(rec, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
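The two-phase contract of a delayed trainer (cheap bookkeeping in `train`, real fitting in `end_train`, with a status tag guarding against repeated work) can be sketched with in-memory objects. `ToyRecorder` and `ToyDelayTrainer` are hypothetical stand-ins and do not use qlib's `Recorder` or `R`; only the tag names are taken from the diff.

```python
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"


class ToyRecorder:
    """In-memory stand-in for a recorder: holds the task, tags, and a model slot."""

    def __init__(self, task):
        self.task = task
        self.tags = {}
        self.model = None


class ToyDelayTrainer:
    """Illustrative delayed trainer: `train` prepares, `end_train` does the fitting."""

    def __init__(self):
        self.fit_calls = 0

    def train(self, tasks):
        recs = []
        for task in tasks:
            rec = ToyRecorder(task)  # only the task config is saved here
            rec.tags[STATUS_KEY] = STATUS_BEGIN
            recs.append(rec)
        return recs

    def end_train(self, recs):
        for rec in recs:
            if rec.tags[STATUS_KEY] == STATUS_END:
                continue  # already fitted, skip
            rec.model = f"fitted({rec.task['model']})"  # placeholder for model.fit
            self.fit_calls += 1
            rec.tags[STATUS_KEY] = STATUS_END
        return recs


trainer = ToyDelayTrainer()
recs = trainer.train([{"model": "lgb"}, {"model": "mlp"}])
pending = [rec.model for rec in recs]

trainer.end_train(recs)
trainer.end_train(recs)  # second call is a no-op thanks to the status tag
```

Separating the phases lets a caller collect many prepared recorders first and run the expensive step later, possibly in parallel.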
class TrainerRM(Trainer):
"""
Trainer based on (R)ecorder and Task(M)anager.
It can train a list of tasks and return a list of model recorders in a multiprocessing way.
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
"""
# These tags help you distinguish whether the Recorder has finished training
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
"""
Init TrainerRM.
Args:
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
def train(
self,
tasks: list,
train_func: Callable = None,
experiment_name: str = None,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
) -> List[Recorder]:
"""
Given a list of `task`s, return a list of trained Recorders. The order of the tasks is preserved.
This method runs in a single process by default, but TaskManager offers a way to parallelize training.
Users can customize their train_func to use multiple processes or even multiple machines.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str): the status assigned to tasks after training, e.g. STATUS_DONE (the default) or STATUS_PART_DONE.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(
train_func,
task_pool,
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
class DelayTrainerRM(TrainerRM):
"""
A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(
self,
experiment_name: str = None,
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
):
"""
Init DelayTrainerRM.
Args:
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method, which needs at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
experiment_name (str): the experiment name. None to use the default name.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
return super().train(
tasks,
train_func=train_func,
experiment_name=experiment_name,
after_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorders, return a list of trained Recorders.
This class finishes the real data loading and model fitting.
NOTE: This method only trains the STATUS_PART_DONE tasks in the task pool that correspond to ``recs`` (see the ``query`` filter passed to ``run_task``), not every STATUS_PART_DONE task.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method, which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name. None to use the default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
for rec in recs:
tasks.append(rec.load_object("task"))
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
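The two-phase flow of `DelayTrainerRM` (a cheap `train` that leaves tasks in STATUS_PART_DONE, then an `end_train` that does the real fitting for the selected tasks only) can be illustrated with a small, self-contained sketch. Everything here is hypothetical: `ToyTaskPool`, `begin_train` and `end_train` are in-memory stand-ins written for illustration, not the MongoDB-backed `TaskManager` or the real `begin_task_train`/`end_task_train`.

```python
from typing import Callable, Dict, List, Optional, Set

STATUS_WAITING = "waiting"
STATUS_PART_DONE = "part_done"
STATUS_DONE = "done"


class ToyTaskPool:
    """Hypothetical in-memory stand-in for a task pool."""

    def __init__(self) -> None:
        self.tasks: List[Dict] = []

    def create(self, defs: List[dict]) -> List[int]:
        """Store task definitions in WAITING status and return their ids."""
        ids = []
        for d in defs:
            self.tasks.append({"def": d, "status": STATUS_WAITING, "res": None})
            ids.append(len(self.tasks) - 1)
        return ids

    def run(self, func: Callable[[dict], object], before: str, after: str,
            only: Optional[Set[int]] = None) -> None:
        """Apply func to tasks in `before` status, optionally restricted to `only` ids."""
        for i, t in enumerate(self.tasks):
            if t["status"] == before and (only is None or i in only):
                t["res"] = func(t)
                t["status"] = after


def begin_train(task: dict) -> dict:
    # Phase 1: only record what should be trained; no fitting happens here.
    return {"task": task["def"], "model": None}


def end_train(task: dict) -> dict:
    # Phase 2: the "real" fitting, simulated here by summing the task's data.
    rec = task["res"]
    rec["model"] = sum(rec["task"]["data"])
    return rec


pool = ToyTaskPool()
ids = pool.create([{"data": [1, 2]}, {"data": [3, 4]}])
pool.run(begin_train, STATUS_WAITING, STATUS_PART_DONE)
pending = [pool.tasks[i]["res"]["model"] for i in ids]   # nothing fitted yet
pool.run(end_train, STATUS_PART_DONE, STATUS_DONE, only=set(ids))
fitted = [pool.tasks[i]["res"]["model"] for i in ids]    # fitted in phase 2
```

Splitting the work this way lets many routines queue their tasks first and run the expensive phase later, possibly in parallel workers.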

View File

@@ -6,6 +6,7 @@ from __future__ import division
from __future__ import print_function
import os
import pickle
import re
import copy
import json
@@ -24,7 +25,9 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Text, Optional
from typing import Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -165,24 +168,25 @@ def parse_field(field):
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
def get_module_by_module_path(module_path):
def get_module_by_module_path(module_path: Union[str, ModuleType]):
"""Load module path
:param module_path:
:return:
"""
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
if isinstance(module_path, ModuleType):
module = module_path
else:
module = importlib.import_module(module_path)
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class and kwargs from config info
@@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
config : [dict, str]
similar to config
module : Python module
default_module : Python module or str
It should be a python module to load the class type
This function will load the class from config['module_path'] first.
If config['module_path'] doesn't exist, it will load the class from default_module.
Returns
-------
@@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
the class object and its arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
kwargs = {}
else:
@@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def init_instance_by_config(
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> object:
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> Any:
"""
get initialized instance with config
@@ -227,13 +237,19 @@ def init_instance_by_config(
'module_path': path, # It is optional if module is given
}
str example.
"ClassName": getattr(module, config)() will be used.
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'
2) specify a class name
- "ClassName": getattr(module, config)() will be used.
object example:
instance of accept_types
module : Python module
default_module : Python module
Optional. It should be a python module.
NOTE: the "module_path" will be override by `module` arguments
This function will load the class from config['module_path'] first.
If config['module_path'] doesn't exist, it will load the class from default_module.
accept_types: Union[type, Tuple[type]]
Optional. If the config is an instance of a specific type, return the config directly.
This will be passed into the second parameter of isinstance.
@@ -246,10 +262,14 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
if module is None:
module = get_module_by_module_path(config["module_path"])
if isinstance(config, str):
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, module)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
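The lookup order described above (a `file://` string loads a pickle, a dict prefers its own `module_path` and falls back to `default_module`, a plain string is a class name in `default_module`) can be sketched with the standard library only. This is a simplified illustration under stated assumptions, not the function in the diff: `resolve_module` and `build` are hypothetical names, `.py` file paths and `accept_types` are left out, and the `file://` handling assumes a POSIX-style absolute path.

```python
import importlib
import os
import pickle
import tempfile
from types import ModuleType
from typing import Any, Union
from urllib.parse import urlparse


def resolve_module(module: Union[str, ModuleType]) -> ModuleType:
    """Accept either an already imported module or a dotted module name."""
    return module if isinstance(module, ModuleType) else importlib.import_module(module)


def build(config: Union[str, dict], default_module: Union[str, ModuleType, None] = None, **kwargs) -> Any:
    # 1) 'file:///.../obj.pkl' -> load the pickled object directly
    if isinstance(config, str) and urlparse(config).scheme == "file":
        with open(urlparse(config).path, "rb") as f:
            return pickle.load(f)
    # 2) dict -> config["module_path"] wins, default_module is the fallback
    if isinstance(config, dict):
        module = resolve_module(config.get("module_path", default_module))
        klass, cls_kwargs = getattr(module, config["class"]), config.get("kwargs", {})
    # 3) plain string -> class name looked up in default_module
    else:
        module = resolve_module(default_module)
        klass, cls_kwargs = getattr(module, config), {}
    return klass(**cls_kwargs, **kwargs)


half = build({"class": "Fraction", "module_path": "fractions",
              "kwargs": {"numerator": 1, "denominator": 2}})
empty = build("OrderedDict", default_module="collections")

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "obj.pkl")
    with open(path, "wb") as f:
        pickle.dump({"answer": 42}, f)
    loaded = build("file://" + path)
```

Note that unpickling executes arbitrary code, so a `file://` config should only point at files from a trusted source.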
@@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
"""get trading date with shift bias relative to cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
"""
from qlib.data import D
cal = D.calendar(future=future)
cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
@@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
return df.sort_index(axis=axis)
def flatten_dict(d, parent_key="", sep="."):
"""flatten_dict.
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
def flatten_dict(d, parent_key="", sep=".") -> dict:
"""
Flatten a nested dict.
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
Parameters
----------
d :
d
parent_key :
parent_key
sep :
sep
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
Args:
d (dict): the dict to flatten
parent_key (str, optional): the parent key, used as a prefix of the new key. Defaults to "".
sep (str, optional): the separator for joining keys as strings. Use FLATTEN_TUPLE to join keys into tuples.
Returns:
dict: the flattened dict
"""
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if sep == FLATTEN_TUPLE:
new_key = (parent_key, k) if parent_key else k
else:
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:
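The tuple-key behaviour shown in the docstring example (nested keys such as `('c','b','x')` forming one flat tuple) can be sketched independently of the hunk above. This is a minimal illustration, assuming string keys and assuming that top-level keys stay bare as in the docstring; the function name `flatten` and the tuple-valued `parent_key` are choices made here, not the signature in the diff.

```python
from collections.abc import MutableMapping

FLATTEN_TUPLE = "_FLATTEN_TUPLE"  # sentinel: join keys into tuples instead of strings


def flatten(d: dict, parent_key: tuple = (), sep: str = ".") -> dict:
    """Flatten a nested dict; `parent_key` is the tuple of keys seen so far."""
    flat = {}
    for k, v in d.items():
        path = (*parent_key, k)  # always extend a flat tuple, never nest tuples
        if isinstance(v, MutableMapping):
            flat.update(flatten(v, path, sep))
        elif sep == FLATTEN_TUPLE:
            # top-level keys stay bare, nested ones become one flat tuple
            flat[path if len(path) > 1 else k] = v
        else:
            flat[sep.join(path)] = v
    return flat


nested = {"a": 1, "c": {"a": 2, "b": {"x": 5, "y": 10}}, "d": [1, 2, 3]}
dotted = flatten(nested)
tupled = flatten(nested, sep=FLATTEN_TUPLE)
```

Carrying the path as a tuple through the recursion is what keeps keys at depth three or more flat; building `(parent_key, k)` from an already-tuple parent would instead produce nested tuples such as `(('c', 'b'), 'x')`.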

View File

@@ -3,16 +3,24 @@
from pathlib import Path
import pickle
import typing
import dill
from typing import Union
class Serializable:
"""
Serializable behaves like pickle.
But it only saves the state whose name **does not** start with `_`
Serializable will change the behaviors of pickle.
- It only saves the state whose name **does not** start with `_`
It provides syntactic sugar for marking the attributes that the user doesn't want to save.
- For example, a learnable DataHandler just wants to save the parameters without data when dumping to disk
"""
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
default_dump_all = False # if dump all things
def __init__(self):
self._dump_all = False
self._dump_all = self.default_dump_all
self._exclude = []
def __getstate__(self) -> dict:
@@ -33,18 +41,86 @@ class Serializable:
@property
def exclude(self):
"""
What attribute will be dumped
What attribute will not be dumped
"""
return getattr(self, "_exclude", [])
def config(self, dump_all: bool = None, exclude: list = None):
if dump_all is not None:
self._dump_all = dump_all
FLAG_KEY = "_qlib_serial_flag"
if exclude is not None:
self._exclude = exclude
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
"""
configure the serializable object
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
Parameters
----------
dump_all : bool
whether to dump all attributes of the object
exclude : list
the attributes that will not be dumped
recursive : bool
whether to apply the configuration recursively to nested Serializable attributes
"""
params = {"dump_all": dump_all, "exclude": exclude}
for k, v in params.items():
if v is not None:
attr_name = f"_{k}"
setattr(self, attr_name, v)
if recursive:
for obj in self.__dict__.values():
# set flag to prevent endless loop
self.__dict__[self.FLAG_KEY] = True
if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
obj.config(**params, recursive=True)
del self.__dict__[self.FLAG_KEY]
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
"""
Dump self to a pickle file.
Args:
path (Union[Path, str]): the path to dump
dump_all (bool, optional): if need to dump all things. Defaults to None.
exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
"""
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
pickle.dump(self, f)
self.get_backend().dump(self, f)
@classmethod
def load(cls, filepath):
"""
Load the serializable object from a filepath.
Args:
filepath (str): the path of file
Raises:
TypeError: the pickled object must be an instance of `cls`
Returns:
`cls`: the instance of `cls`
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)
if isinstance(object, cls):
return object
else:
raise TypeError(f"The instance of {type(object)} is not a valid `{cls.__name__}`!")
@classmethod
def get_backend(cls):
"""
Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
Returns:
module: pickle or dill module based on pickle_backend
"""
if cls.pickle_backend == "pickle":
return pickle
elif cls.pickle_backend == "dill":
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
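The attribute filtering that `Serializable` applies when dumping can be sketched in a few lines. This is a simplified, hypothetical stand-in (`SketchSerializable` and `Handler` are names invented here), assuming the state filter keeps an attribute when it is not excluded and either `dump_all` is set or its name does not start with `_`; the body of the real `__getstate__` is not shown in this hunk.

```python
class SketchSerializable:
    """Hypothetical, simplified stand-in for the class in the diff."""

    default_dump_all = False  # class-level default, overridable in subclasses

    def __init__(self):
        self._dump_all = self.default_dump_all
        self._exclude = []

    def __getstate__(self) -> dict:
        # pickle (or dill) calls this to decide what is written to disk
        return {
            k: v
            for k, v in self.__dict__.items()
            if k not in self._exclude and (self._dump_all or not k.startswith("_"))
        }


class Handler(SketchSerializable):
    def __init__(self):
        super().__init__()
        self.params = {"mean": 0.5}      # learned parameters: worth saving
        self._data = list(range(1000))   # bulky data: skipped by default


h = Handler()
default_state = h.__getstate__()   # only public attributes
h._dump_all = True
full_state = h.__getstate__()      # everything, including `_data`
```

Because the filter lives in `__getstate__`, the same rule applies whichever backend module performs the dump.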

View File

@@ -331,7 +331,7 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
"""
Method for retrieving a recorder.

View File

View File

@@ -0,0 +1,304 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
As time goes on, the decisive models also change. In this module, we call those contributing models `online` models.
In every routine (such as every day or every minute), the `online` models may change and their predictions need to be updated.
So this module provides a series of methods to control this process.
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history,
which means you can verify your strategy or find a better one.
There are 4 situations in total, combining online or simulated usage with the two kinds of trainers:
========================= ===================================================================================
Situations                Description
========================= ===================================================================================
Online + Trainer          When you want to run a REAL routine, the Trainer will help you train the models.
Online + DelayTrainer     In a normal online routine, both Trainer and DelayTrainer REALLY train the models
                          in this routine, so DelayTrainer is not necessary for a REAL routine.
Simulation + Trainer      When your models have some temporal dependence on the previous models, you
                          need to consider using Trainer. It will REALLY train your models in
                          every routine and prepare signals for every routine.
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
                          for its multitasking ability. All tasks in all routines
                          can be REALLY trained at the end of the simulation. The signals will be prepared
                          at different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
import logging
from typing import Callable, Dict, List, Union
import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
from qlib.utils import flatten_dict
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.task.collect import MergeCollector
class OnlineManager(Serializable):
"""
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
It also provides a history recording of which models are online at what time.
"""
STATUS_SIMULATING = "simulating" # when calling `simulate`
STATUS_NORMAL = "normal" # the normal status
def __init__(
self,
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
trainer: Trainer = None,
begin_time: Union[str, pd.Timestamp] = None,
freq="day",
):
"""
Init OnlineManager.
One OnlineManager must have at least one OnlineStrategy.
Args:
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date.
trainer (Trainer): the trainer to train task. None for using TrainerR.
freq (str, optional): data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
if not isinstance(strategies, list):
strategies = [strategies]
self.strategies = strategies
self.freq = freq
if begin_time is None:
begin_time = D.calendar(freq=self.freq).max()
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
# OnlineManager will record the history of online models, which is a dict like {pd.Timestamp: {strategy: [online_models]}}.
self.history = {}
if trainer is None:
trainer = TrainerR()
self.trainer = trainer
self.signals = None
self.status = self.STATUS_NORMAL
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
Get tasks from every strategy's first_tasks method and train them.
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
Args:
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
model_kwargs (dict): the params for `prepare_online_models`
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
task_kwargs: dict = {},
model_kwargs: dict = {},
signal_kwargs: dict = {},
):
"""
Run the typical update process for every strategy and record the online history.
This process is run after each routine, such as day by day or month by month.
The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.
If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
Args:
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
signal_kwargs (dict): the params for `prepare_signals`
"""
if cur_time is None:
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
"""
Add some new strategies to OnlineManager.
Args:
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an OnlineStrategy or a list of OnlineStrategy
"""
if not isinstance(strategies, list):
strategies = [strategies]
self.first_train(strategies)
self.strategies.extend(strategies)
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
"""
After preparing the data of the last routine (a box in the box-plot), which marks the end of the routine, we can prepare trading signals for the next routine.
NOTE: Given a set of predictions, all signals before the end times of these predictions will be prepared.
Even if the latest signal already exists, the latest calculation result will be overwritten.
.. note::
Given a prediction of a certain time, all signals before this time will be prepared.
Args:
prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}.
over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False.
Returns:
pd.DataFrame: the signals.
"""
signals = prepare_func(self.get_collector()())
old_signals = self.signals
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
self.signals = signals
return new_signals
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
"""
Get prepared online signals.
Returns:
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
return self.signals
SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
Considering the parallel training, the models and signals can be prepared after all routine simulating.
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
Args:
end_time: the time the simulation will end
frequency: the calendar frequency
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
signal_kwargs (dict): the params for `prepare_signals`
Returns:
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
self.status = self.STATUS_SIMULATING
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
self.first_train()
simulate_level = self.SIM_LOG_LEVEL
set_global_logger_level(simulate_level)
logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
for cur_time in cal:
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
self.routine(
cur_time,
task_kwargs=task_kwargs,
model_kwargs=model_kwargs,
signal_kwargs=signal_kwargs,
)
# delay prepare the models and signals
if self.trainer.is_delay():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
self.status = self.STATUS_NORMAL
return self.get_signals()
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
"""
Prepare all models and signals if something is waiting for preparation.
Args:
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
last_models = {}
signals_time = D.calendar()[0]
need_prepare = False
for cur_time, strategy_models in self.history.items():
self.cur_time = cur_time
for strategy, models in strategy_models.items():
# only new online models need to prepare
if last_models.setdefault(strategy, set()) != set(models):
models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
strategy.tool.reset_online_tag(models)
need_prepare = True
last_models[strategy] = set(models)
if need_prepare:
# NOTE: Assumption: the predictions of online models must end before the next cur_time; otherwise this method will not work correctly.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
self.logger.warning(
f"The signals have already been prepared to {signals_time} by the last preparation, but the current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)
need_prepare = False
signals_time = self.signals.index.get_level_values("datetime").max()
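The core idea of `delay_prepare` (walk the recorded history in order and only do work at the timestamps where some strategy's set of online models actually changed) can be shown without qlib. This is a minimal sketch with hypothetical names: `times_needing_preparation` is invented here, strategies and models are plain strings, and the training and signal preparation steps are reduced to recording the timestamp.

```python
from typing import Dict, Hashable, List


def times_needing_preparation(history: Dict[Hashable, Dict[str, List[str]]]) -> list:
    """Return the timestamps at which any strategy's online model set changed."""
    last_models: Dict[str, set] = {}
    changed_at = []
    for cur_time, strategy_models in history.items():  # dicts keep insertion order
        need_prepare = False
        for strategy, models in strategy_models.items():
            # only a new set of online models triggers preparation
            if last_models.setdefault(strategy, set()) != set(models):
                need_prepare = True
            last_models[strategy] = set(models)
        if need_prepare:
            changed_at.append(cur_time)
    return changed_at


history = {
    "2021-01-04": {"rolling": ["m1"]},
    "2021-01-05": {"rolling": ["m1"]},               # unchanged: skipped
    "2021-01-06": {"rolling": ["m2"]},               # new model goes online
    "2021-01-07": {"rolling": ["m2"], "other": []},  # empty set equals the initial state
}
changed = times_needing_preparation(history)
```

Comparing sets rather than lists means a reordering of the same online models does not count as a change.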

View File

@@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineStrategy module is an element of online serving.
"""
from copy import deepcopy
from typing import List, Tuple, Union
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster
class OnlineStrategy:
"""
OnlineStrategy works with `Online Manager <#Online Manager>`_ and is responsible for how the tasks are generated, how the models are updated and how the signals are prepared.
"""
def __init__(self, name_id: str):
"""
Init OnlineStrategy.
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finish model training.
Args:
name_id (str): a unique name or id.
"""
self.name_id = name_id
self.logger = get_module_logger(self.__class__.__name__)
self.tool = OnlineTool()
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
"""
After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for the latest).
Return the new tasks waiting for training.
You can find the last online models by OnlineTool.online_models.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
"""
Select some models from trained models and set them to online models.
This is a typical implementation to online all trained models, you can override it to implement the complex method.
You can find the last online models by OnlineTool.online_models if you still need them.
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
Args:
trained_models (list): a list of trained models.
cur_time (pd.Timestamp): current time from OnlineManager. None for the latest.
Returns:
List[object]: a list of online models.
"""
if not trained_models:
return self.tool.online_models()
self.tool.reset_online_tag(trained_models)
return trained_models
def first_tasks(self) -> List[dict]:
"""
Generate a series of tasks firstly and return them.
"""
raise NotImplementedError(f"Please implement the `first_tasks` method.")
def get_collector(self) -> Collector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.
For example:
1) collect predictions in Recorder
2) collect signals in a txt file
Returns:
Collector
"""
raise NotImplementedError(f"Please implement the `get_collector` method.")
class RollingStrategy(OnlineStrategy):
"""
This example strategy always uses the latest rolling models as online models.
"""
def __init__(
self,
name_id: str,
task_template: Union[dict, List[dict]],
rolling_gen: RollingGen,
):
"""
Init RollingStrategy.
Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.
Args:
name_id (str): a unique name or id. Will be also the name of the Experiment.
task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
rolling_gen (RollingGen): an instance of RollingGen
"""
super().__init__(name_id=name_id)
self.exp_name = self.name_id
if not isinstance(task_template, list):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.
Assumption: the models can be distinguished based on the model name and rolling test segments.
If you do not want this assumption, please implement your method or use another rec_key_func.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
"""
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
if rec_key_func is None:
rec_key_func = rec_key
artifacts_collector = RecorderCollector(
experiment=self.exp_name,
process_list=process_list,
rec_key_func=rec_key_func,
rec_filter_func=rec_filter_func,
artifacts_key=artifacts_key,
)
return artifacts_collector
def first_tasks(self) -> List[dict]:
"""
Use rolling_gen to generate different tasks based on task_template.
Returns:
List[dict]: a list of tasks
"""
return task_generator(
tasks=self.task_template,
generators=self.rg, # generate different date segment
)
def prepare_tasks(self, cur_time) -> List[dict]:
"""
Prepare new tasks based on cur_time (None for the latest).
You can find the last online models by OnlineToolR.online_models.
Returns:
List[dict]: a list of new tasks.
"""
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warning("No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
def _list_latest(self, rec_list: List[Recorder]):
"""
List the latest recorders from rec_list
Args:
rec_list (List[Recorder]): a list of Recorder
Returns:
List[Recorder], tuple: the latest recorders and their test segment (begin time, end time)
"""
if len(rec_list) == 0:
return rec_list, None
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
latest_rec = []
for rec in rec_list:
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec.append(rec)
return latest_rec, max_test
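The rolling decision in `prepare_tasks` (generate new tasks only when the number of trading days between the current time and the begin time of the last test segment reaches the rolling step) can be sketched with a plain sorted calendar. The helpers below are hypothetical stand-ins: how `TimeAdjuster.cal_interval` counts days is not shown in this diff, so counting index distance in the calendar is an assumption, as is the use of ISO date strings.

```python
from bisect import bisect_left
from typing import List


def cal_interval(calendar: List[str], a: str, b: str) -> int:
    """Trading-day distance between two dates that are both in the sorted calendar."""
    return abs(bisect_left(calendar, a) - bisect_left(calendar, b))


def should_roll(calendar: List[str], cur_time: str, last_test_begin: str, step: int) -> bool:
    """True once at least `step` trading days have passed since the last test begin."""
    return cal_interval(calendar, cur_time, last_test_begin) >= step


# ISO date strings sort chronologically, so bisect works on them directly
calendar = ["2021-01-04", "2021-01-05", "2021-01-06", "2021-01-07", "2021-01-08"]
too_early = should_roll(calendar, "2021-01-06", "2021-01-04", step=3)  # interval is 2
on_time = should_roll(calendar, "2021-01-07", "2021-01-04", step=3)    # interval is 3
```

Measuring the interval on the trading calendar rather than in calendar days keeps weekends and holidays from triggering an early roll.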

View File

@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Updater is a module to update artifacts such as predictions when the stock data is updating.
"""
from abc import ABCMeta, abstractmethod
import pandas as pd
from qlib import get_module_logger
from qlib.data import D
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
from qlib.workflow.recorder import Recorder
class RMDLoader:
"""
Recorder Model Dataset Loader
"""
def __init__(self, rec: Recorder):
self.rec = rec
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
"""
Load, config and setup dataset.
This dataset is for inference.
Args:
start_time :
the start_time of underlying data
end_time :
the end_time of underlying data
segments : dict
the segments config for dataset
Because of the time series dataset (TSDatasetH), the test segments may be different from start_time and end_time
Returns:
DatasetH: the instance of DatasetH
"""
if segments is None:
segments = {"test": (start_time, end_time)}
dataset: DatasetH = self.rec.load_object("dataset")
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
return dataset
def get_model(self) -> Model:
return self.rec.load_object("params.pkl")
class RecordUpdater(metaclass=ABCMeta):
"""
Update a specific recorder
"""
def __init__(self, record: Recorder, *args, **kwargs):
self.record = record
self.logger = get_module_logger(self.__class__.__name__)
@abstractmethod
def update(self, *args, **kwargs):
"""
Update info for specific recorder
"""
...
class PredUpdater(RecordUpdater):
"""
Update the prediction in the Recorder
"""
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
"""
Init PredUpdater.
Args:
record : Recorder
to_date :
update the prediction up to `to_date`
hist_ref : int
Sometimes the dataset has historical dependencies.
Users are responsible for setting the length of the historical dependency.
.. note::
the start_time is not included in the hist_ref
"""
# TODO: automate this hist_ref in the future.
super().__init__(record=record)
self.to_date = to_date
self.hist_ref = hist_ref
self.freq = freq
self.rmdl = RMDLoader(rec=record)
if to_date is None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()
def prepare_data(self) -> DatasetH:
"""
Load dataset
Separating this function will make it easier to reuse the dataset
Returns:
DatasetH: the instance of DatasetH
"""
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
seg = {"test": (start_time, self.to_date)}
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
return dataset
def update(self, dataset: DatasetH = None):
"""
Update the prediction in a recorder.
Args:
dataset (DatasetH): the instance of DatasetH. If None, the dataset is prepared again.
"""
# FIXME: the problem below is not solved
# A model dumped on a GPU instance cannot be loaded on a CPU instance. The following exception will be raised:
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
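# `start_time` is the first date without a prediction, so it still needs updating when it equals `to_date`
if start_time > self.to_date: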
self.logger.info(
f"The predictions in {self.record.info['id']} are already up to date ({start_time}). No need to update to {self.to_date}."
)
return
# load dataset
if dataset is None:
# For reusing the dataset
dataset = self.prepare_data()
# Load model
model = self.rmdl.get_model()
new_pred: pd.Series = model.predict(dataset)
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
self.record.save_objects(**{"pred.pkl": cb_pred})
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
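
The date arithmetic in `prepare_data` can be illustrated with a plain list as the trading calendar. This is a sketch under assumptions: `shift_date` is a hypothetical, simplified analogue of a calendar shift (it is not qlib's `get_date_by_shift`), and dates are ISO strings.

```python
from typing import List, Tuple


def shift_date(calendar: List[str], date: str, shift: int) -> str:
    """Move `shift` trading days away from `date` on the given calendar (no clipping)."""
    idx = calendar.index(date) + shift
    if idx < 0 or idx >= len(calendar):
        raise IndexError("the shifted date is outside the calendar")
    return calendar[idx]


def update_window(calendar: List[str], last_end: str, to_date: str, hist_ref: int) -> Tuple[str, str, str]:
    """Return (data loading start, first date to predict, last date to predict)."""
    # Load extra history so that models with historical dependencies have enough input.
    load_start = shift_date(calendar, last_end, -hist_ref + 1)
    # Predictions resume on the trading day after the last stored prediction.
    pred_start = shift_date(calendar, last_end, 1)
    return load_start, pred_start, to_date


calendar = ["2021-05-10", "2021-05-11", "2021-05-12", "2021-05-13", "2021-05-14", "2021-05-17"]
window = update_window(calendar, last_end="2021-05-13", to_date="2021-05-17", hist_ref=3)
```

With `hist_ref=0` the loading start and the prediction start coincide, which matches a dataset without historical dependencies.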


@@ -0,0 +1,168 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineTool is a module to set and unset a series of `online` models.
The `online` models are the models that make the decisions at a given point in time, and they can change as time goes by.
This allows us to switch to more effective submodels as the market style changes.
"""
from typing import List, Union
from qlib.log import get_module_logger
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
class OnlineTool:
"""
OnlineTool will manage `online` models in an experiment that includes the model recorders.
"""
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self):
"""
Init OnlineTool.
"""
self.logger = get_module_logger(self.__class__.__name__)
def set_online_tag(self, tag, recorder: Union[list, object]):
"""
Set `tag` on the model to mark whether it is online.
Args:
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[list,object]): the model's recorder
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, recorder: object) -> str:
"""
Return the online tag of the given model recorder.
Args:
recorder (Object): the model's recorder
Returns:
str: the online tag
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, recorder: Union[list, object]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[list,object]):
the recorder you want to reset to 'online'.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def update_online_pred(self, to_date=None):
"""
Update the predictions of `online` models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
class OnlineToolR(OnlineTool):
"""
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
"""
super().__init__()
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Set `tag` on the model's recorder to mark whether it is online.
Args:
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder) -> str:
"""
Return the online tag of the given model recorder.
Args:
recorder (Recorder): an instance of recorder
Returns:
str: the online tag
"""
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
"""
Update the predictions of online models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
"""
online_models = self.online_models()
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
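
The tagging logic of `OnlineToolR` can be sketched with in-memory objects. `FakeRecorder` below is a hypothetical stand-in that only stores tags; the constants mirror the ones defined on `OnlineTool`, and the free functions are illustrative rather than the actual API.

```python
ONLINE_KEY = "online_status"
ONLINE_TAG = "online"
OFFLINE_TAG = "offline"


class FakeRecorder:
    """Hypothetical in-memory recorder that only supports tags."""

    def __init__(self, rid):
        self.rid = rid
        self.tags = {}

    def set_tags(self, **kwargs):
        self.tags.update(kwargs)

    def list_tags(self):
        return dict(self.tags)


def get_online_tag(rec):
    # A recorder without the tag is treated as offline.
    return rec.list_tags().get(ONLINE_KEY, OFFLINE_TAG)


def reset_online_tag(all_recs, online_recs):
    # First take every model offline, then bring the chosen ones online.
    for rec in all_recs:
        rec.set_tags(**{ONLINE_KEY: OFFLINE_TAG})
    for rec in online_recs:
        rec.set_tags(**{ONLINE_KEY: ONLINE_TAG})


def online_models(all_recs):
    return [rec for rec in all_recs if get_online_tag(rec) == ONLINE_TAG]


recs = [FakeRecorder(i) for i in range(3)]
reset_online_tag(recs, [recs[0]])
reset_online_tag(recs, recs[1:])
online_ids = [rec.rid for rec in online_models(recs)]
```

Resetting in two passes means that a model which was online before is taken offline unless it is passed in again.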


@@ -151,6 +151,10 @@ class SignalRecord(RecordTemp):
del params["data_key"]
# The backend handler should be DataHandler
raw_label = self.dataset.prepare(**params)
except AttributeError:
# The data handler is initialized with `drop_raw=True`...
# So raw_label is not available
raw_label = None
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls
@@ -236,6 +240,9 @@ class SigAnaRecord(SignalRecord):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warning("Empty label.")
return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),


@@ -39,6 +39,9 @@ class Recorder:
def __str__(self):
return str(self.info)
def __hash__(self) -> int:
return hash(self.info["id"])
@property
def info(self):
output = dict()
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
client=self.client,
)
def __hash__(self) -> int:
return hash(self.info["id"])
def __eq__(self, o: object) -> bool:
if isinstance(o, MLflowRecorder):
return self.info["id"] == o.info["id"]
return False
@property
def uri(self):
return self._uri


@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Task related workflow is implemented in this folder
A typical task workflow
| Step | Description |
|-----------------------+------------------------------------------------|
| TaskGen | Generating tasks. |
| TaskManager(optional) | Manage generated tasks |
| run task | retrieve tasks from TaskManager and run tasks. |
"""


@@ -0,0 +1,219 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
"""
from typing import Callable, Dict, List
from qlib.utils.serial import Serializable
from qlib.workflow import R
class Collector(Serializable):
"""The collector to collect different results"""
pickle_backend = "dill" # use dill to dump user method
def __init__(self, process_list=[]):
"""
Init Collector.
Args:
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
"""
if not isinstance(process_list, list):
process_list = [process_list]
self.process_list = process_list
def collect(self) -> dict:
"""
Collect the results and return a dict like {key: things}
Returns:
dict: the dict after collecting.
For example:
{"prediction": pd.Series}
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
......
"""
raise NotImplementedError(f"Please implement the `collect` method.")
@staticmethod
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
"""
Do a series of processing to the dict returned by collect and return a dict like {key: things}
For example, you can group and ensemble.
Args:
collected_dict (dict): the dict returned by `collect`
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
The processor order is the same as the list order.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
Returns:
dict: the dict after processing.
"""
if not isinstance(process_list, list):
process_list = [process_list]
result = {}
for artifact in collected_dict:
value = collected_dict[artifact]
for process in process_list:
if not callable(process):
raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
value = process(value, *args, **kwargs)
result[artifact] = value
return result
def __call__(self, *args, **kwargs) -> dict:
"""
Do the workflow including ``collect`` and ``process_collect``
Returns:
dict: the dict after collecting and processing.
"""
collected = self.collect()
return self.process_collect(collected, self.process_list, *args, **kwargs)
class MergeCollector(Collector):
"""
A collector to collect the results of other Collectors
For example:
We have two collectors, named A and B.
A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
......
"""
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
"""
Init MergeCollector.
Args:
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
"""
super().__init__(process_list=process_list)
self.collector_dict = collector_dict
self.merge_func = merge_func
def collect(self) -> dict:
"""
Collect all results of collector_dict and change the outermost key to a recombination key.
Returns:
dict: the dict after collecting.
"""
collect_dict = {}
for collector_key, collector in self.collector_dict.items():
tmp_dict = collector()
for key, value in tmp_dict.items():
if self.merge_func is not None:
collect_dict[self.merge_func(collector_key, key)] = value
else:
collect_dict[(collector_key, key)] = value
return collect_dict
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
def __init__(
self,
experiment,
process_list=[],
rec_key_func=None,
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
):
"""
Init RecorderCollector.
Args:
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
self.experiment = experiment
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
if artifacts_key is None:
artifacts_key = list(self.artifacts_path.keys())
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
Collect different artifacts based on recorder after filtering.
Args:
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
only_exist (bool, optional): whether to collect only the artifacts that a recorder actually has.
If True, artifacts that raise an exception when loading are skipped. If False, the exception is raised.
Returns:
dict: the collected dict, like {artifact: {rec_key: object}}
"""
if artifacts_key is None:
artifacts_key = self.artifacts_key
if rec_filter_func is None:
rec_filter_func = self.rec_filter_func
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
collect_dict = {}
# filter records
recs = self.experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
for _, rec in recs_flt.items():
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
if self.ART_KEY_RAW == key:
artifact = rec
else:
try:
artifact = rec.load_object(self.artifacts_path[key])
except Exception as e:
if only_exist:
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict
def get_exp_name(self) -> str:
"""
Get experiment name
Returns:
str: experiment name
"""
return self.experiment.name
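
The two ideas in this module, chaining processors over each collected artifact and recombining keys when merging collectors, can be sketched with plain dicts. The function names below are illustrative assumptions, and `average` is a hypothetical processor, not part of qlib.

```python
from typing import Callable, Dict, List, Optional


def process_collect(collected: dict, process_list: List[Callable]) -> dict:
    """Apply every processor, in list order, to each collected artifact."""
    result = {}
    for artifact, value in collected.items():
        for process in process_list:
            value = process(value)
        result[artifact] = value
    return result


def merge_collect(collected_by_name: Dict[str, dict], merge_func: Optional[Callable] = None) -> dict:
    """Merge the results of several collectors under recombined outermost keys."""
    merged = {}
    for name, collected in collected_by_name.items():
        for key, value in collected.items():
            # Without a merge function, a (collector_key, key) tuple is used.
            new_key = merge_func(name, key) if merge_func is not None else (name, key)
            merged[new_key] = value
    return merged


def average(rec_dict: dict) -> float:
    """Hypothetical processor: average the values of all recorders."""
    return sum(rec_dict.values()) / len(rec_dict)


processed = process_collect({"IC": {"rec_a": 1.0, "rec_b": 3.0}}, [average])
merged_tuple = merge_collect({"A": {"pred": 1}, "B": {"IC": 2}})
merged_str = merge_collect({"A": {"pred": 1}, "B": {"IC": 2}}, lambda name, key: f"{name}_{key}")
```

Tuple keys avoid accidental collisions between collectors, while a custom merge function gives flat, readable keys such as `"A_pred"`.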

231
qlib/workflow/task/gen.py Normal file

@@ -0,0 +1,231 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
"""
import abc
import copy
from typing import List, Union, Callable
from .utils import TimeAdjuster
def task_generator(tasks, generators) -> list:
"""
Use a list of TaskGen and a list of task templates to generate different tasks.
For example:
There are 3 task templates a,b,c and 2 TaskGen A,B. A generates 2 tasks from a template and B generates 3 tasks from a template.
task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
Parameters
----------
tasks : List[dict] or dict
a list of task templates or a single task
generators : List[TaskGen] or TaskGen
a list of TaskGen or a single TaskGen
Returns
-------
list
a list of tasks
"""
if isinstance(tasks, dict):
tasks = [tasks]
if isinstance(generators, TaskGen):
generators = [generators]
# generate gen_task_list
for gen in generators:
new_task_list = []
for task in tasks:
new_task_list.extend(gen.generate(task))
tasks = new_task_list
return tasks
class TaskGen(metaclass=abc.ABCMeta):
"""
The base class for generating different tasks
Example 1:
input: a specific task template and rolling steps
output: rolling version of the tasks
Example 2:
input: a specific task template and losses list
output: a set of tasks with different losses
"""
@abc.abstractmethod
def generate(self, task: dict) -> List[dict]:
"""
Generate different tasks based on a task template
Parameters
----------
task: dict
a task template
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass
def __call__(self, *args, **kwargs):
"""
This is just a syntactic sugar for generate
"""
return self.generate(*args, **kwargs)
def handler_mod(task: dict, rolling_gen):
"""
Help to modify the handler end time when using RollingGen
Args:
task (dict): a task template
rolling_gen (RollingGen): an instance of RollingGen
"""
try:
interval = rolling_gen.ta.cal_interval(
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
)
# if end_time < the end of test_segments, then change end_time to allow load more data
if interval < 0:
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
)
except KeyError:
# The dataset may not have a handler; in that case, do nothing.
pass
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
"""
Generate tasks for rolling
Parameters
----------
step : int
the rolling step
rtype : str
rolling type (expanding, sliding)
ds_extra_mod_func: Callable
A method like: handler_mod(task: dict, rg: RollingGen)
Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
"""
self.step = step
self.rtype = rtype
self.ds_extra_mod_func = ds_extra_mod_func
self.ta = TimeAdjuster(future=True)
self.test_key = "test"
self.train_key = "train"
def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
Parameters
----------
task: dict
A dict describing a task. For example.
.. code-block:: python
DEFAULT_TASK = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
},
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
]
}
Returns
-------
List[dict]: a list of tasks
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
return res
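
The rolling logic of `RollingGen.generate` can be sketched with integer calendar indices instead of dates. This is a simplified illustration under assumptions: segments are `(start, end)` index pairs, no calendar alignment or clipping is done, and `rolling_segments` is a hypothetical name.

```python
from typing import Dict, List, Tuple

Segment = Tuple[int, int]
ROLL_EX = "expanding"  # fixed train start, expanding train end
ROLL_SD = "sliding"  # fixed segment size, slid forward by `step`


def rolling_segments(segments: Dict[str, Segment], step: int, rtype: str) -> List[Dict[str, Segment]]:
    """Split the test range into windows of `step` and shift the other segments along."""
    test_start, test_end = segments["test"]
    # First rolling: keep train/valid and cut the test segment down to one step.
    first = dict(segments)
    first["test"] = (test_start, test_start + step - 1)
    result = [first]
    prev = first
    while True:
        cur = {}
        for name, (start, end) in prev.items():
            # Only the train segment expands; valid and test always slide.
            expanding = name == "train" and rtype == ROLL_EX
            cur[name] = (start if expanding else start + step, end + step)
        if cur["test"][0] > test_end:
            break  # the next test window starts after the original test end
        result.append(cur)
        prev = cur
    return result


template = {"train": (0, 99), "valid": (100, 119), "test": (120, 199)}
expanding = rolling_segments(template, step=40, rtype=ROLL_EX)
sliding = rolling_segments(template, step=40, rtype=ROLL_SD)
```

Both rolling types produce the same test windows; they differ only in whether the train segment keeps its original start.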


@@ -0,0 +1,493 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
These features can run tasks concurrently and ensure every task will be used only once.
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
A task in TaskManager consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result: A user can get the task with the task description and task result.
"""
import concurrent.futures
import pickle
import time
from contextlib import contextmanager
from typing import Callable, List
import fire
import pymongo
from bson.binary import Binary
from bson.objectid import ObjectId
from pymongo.errors import InvalidDocument
from qlib import auto_init, get_module_logger
from tqdm.cli import tqdm
from .utils import get_mongodb
class TaskManager:
"""
TaskManager
Here is what a task looks like when it is created by TaskManager
.. code-block:: python
{
'def': pickle serialized task definition. using pickle will make it easier
'filter': json-like data. This is for filtering the tasks.
'status': 'waiting' | 'running' | 'part_done' | 'done'
'res': pickle serialized task result,
}
The tasks manager assumes that you will only update the tasks you fetched.
MongoDB's atomic find-one-and-update operation keeps the data update safe.
.. note::
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
There are four statuses:
STATUS_WAITING: waiting for training
STATUS_RUNNING: training
STATUS_PART_DONE: finished some step and waiting for next step
STATUS_DONE: all work done
"""
STATUS_WAITING = "waiting"
STATUS_RUNNING = "running"
STATUS_DONE = "done"
STATUS_PART_DONE = "part_done"
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
"""
Init Task Manager. Remember to configure the MongoDB URL and database name first.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
"""
List all collections (task pools) of the database
Returns:
list
"""
return self.mdb.list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = Binary(pickle.dumps(task[k]))
return task
def _decode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = pickle.loads(task[k])
return task
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def replace_task(self, task, new_task):
"""
Use a new task to replace an old one
Args:
task: old task
new_task: new task
"""
new_task = self._encode_task(new_task)
query = {"_id": ObjectId(task["_id"])}
try:
self.task_pool.replace_one(query, new_task)
except InvalidDocument:
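# the document being written is `new_task`, so its filter is the one to convert before retrying
new_task["filter"] = self._dict_to_str(new_task["filter"])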
self.task_pool.replace_one(query, new_task)
def insert_task(self, task):
"""
Insert a task.
Args:
task: the task to insert
Returns:
pymongo.results.InsertOneResult
"""
try:
insert_result = self.task_pool.insert_one(task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
insert_result = self.task_pool.insert_one(task)
return insert_result
def insert_task_def(self, task_def):
"""
Insert a task to task_pool
Parameters
----------
task_def: dict
the task definition
Returns
-------
pymongo.results.InsertOneResult
"""
task = self._encode_task(
{
"def": task_def,
"filter": task_def, # FIXME: catch the raised error
"status": self.STATUS_WAITING,
}
)
insert_result = self.insert_task(task)
return insert_result
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
"""
If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.
If a task is not new, then just query its _id.
Parameters
----------
task_def_l: list
a list of task
dry_run: bool
if True, only report the new tasks without inserting them into the task pool
print_nt: bool
whether to print the new tasks
Returns
-------
List[str]
a list of the _id of task_def_l
"""
new_tasks = []
_id_list = []
for t in task_def_l:
try:
r = self.task_pool.find_one({"filter": t})
except InvalidDocument:
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
if not dry_run:
insert_result = self.insert_task_def(t)
_id_list.append(insert_result.inserted_id)
else:
_id_list.append(None)
else:
_id_list.append(self._decode_task(r)["_id"])
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
if print_nt: # print new task
for t in new_tasks:
print(t)
if dry_run:
return []
return _id_list
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
"""
Use query to fetch tasks.
Args:
query (dict, optional): query dict. Defaults to {}.
status (str, optional): the status of the tasks to fetch. Defaults to STATUS_WAITING.
Returns:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
# null sorts first with ASCENDING, so DESCENDING is used: the larger the number, the higher the priority
if task is None:
return None
task["status"] = self.STATUS_RUNNING
return self._decode_task(task)
@contextmanager
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
"""
Fetch task from task_pool using query with contextmanager
Parameters
----------
query: dict
the dict of query
Returns
-------
dict: a task(document in collection) after decoding
"""
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
if task is not None:
self.logger.info("Returning task before raising error")
self.return_task(task)
self.logger.info("Task returned")
raise
def task_fetcher_iter(self, query={}):
while True:
with self.safe_fetch_task(query=query) as task:
if task is None:
break
yield task
def query(self, query={}, decode=True):
"""
Query task in collection.
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
Parameters
----------
query: dict
the dict of query
decode: bool
whether to decode the tasks before yielding them
Returns
-------
dict: a task(document in collection), decoded if `decode` is True
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
for t in self.task_pool.find(query):
yield self._decode_task(t) if decode else t
def re_query(self, _id):
"""
Use _id to query task.
Args:
_id (str): _id of a document
Returns:
dict: a task(document in collection) after decoding
"""
t = self.task_pool.find_one({"_id": ObjectId(_id)})
return self._decode_task(t)
def commit_task_res(self, task, res, status=STATUS_DONE):
"""
Commit the result to task['res'].
Args:
task (dict): the task fetched from the task pool
res (object): the result you want to save
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
"""
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
def return_task(self, task, status=STATUS_WAITING):
"""
Reset a task to the given status. Mostly used in error handling.
Args:
task (dict): the task fetched from the task pool
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
"""
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {"$set": {"status": status}}
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def remove(self, query={}):
"""
Remove the task using query
Parameters
----------
query: dict
the dict of query
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
"""
Count the tasks in every status.
Args:
query (dict, optional): the query dict. Defaults to {}.
Returns:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
return status_stat
def reset_waiting(self, query={}):
"""
Reset all running tasks to the waiting status. Can be used when some running tasks exit unexpectedly.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
query = query.copy()
# default query
if "status" not in query:
query["status"] = self.STATUS_RUNNING
return self.reset_status(query=query, status=self.STATUS_WAITING)
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
"""
Set priority for task
Parameters
----------
task : dict
The task queried from the database
priority : int
the target priority
"""
update_dict = {"$set": {"priority": priority}}
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
def _get_total(self, task_stat):
return sum(task_stat.values())
def wait(self, query={}):
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
undone_n = self._get_undone_n(self.task_stat(query))
pbar.update(last_undone_n - undone_n)
last_undone_n = undone_n
if undone_n == 0:
break
def __str__(self):
return f"TaskManager({self.task_pool})"
def run_task(
task_func: Callable,
task_pool: str,
query: dict = {},
force_release: bool = False,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
):
"""
While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool.
After running this method, there are 4 possible situations (before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
Parameters
----------
task_func : Callable
def (task_def, **kwargs) -> <res which will be committed>
the function to run the task
task_pool : str
the name of the task pool (Collection in MongoDB)
query: dict
will use this dict to query task_pool when fetching task
force_release : bool
whether to run each task in a separate process so that its resources are released when the task finishes
before_status : str
the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status : str
the tasks will be set to after_status after being trained. Can be STATUS_DONE, STATUS_PART_DONE.
kwargs
the params for `task_func`
Returns
-------
bool
True if at least one task was fetched and run, otherwise False
"""
tm = TaskManager(task_pool)
ever_run = False
while True:
with tm.safe_fetch_task(status=before_status, query=query) as task:
if task is None:
break
get_module_logger("run_task").info(task["def"])
# when fetching `WAITING` task, use task["def"] to train
if before_status == TaskManager.STATUS_WAITING:
param = task["def"]
# when fetching `PART_DONE` task, use task["res"] to train because the intermediate result has been saved to task["res"]
elif before_status == TaskManager.STATUS_PART_DONE:
param = task["res"]
else:
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, param, **kwargs).result()
else:
res = task_func(param, **kwargs)
tm.commit_task_res(task, res, status=after_status)
ever_run = True
return ever_run
if __name__ == "__main__":
# This is for using it in cmd
# E.g. : `python -m qlib.workflow.task.manage list`
auto_init()
fire.Fire(TaskManager)
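
The status transitions described in the `run_task` docstring can be illustrated without MongoDB. The following is a minimal sketch, not the module's implementation: `run_task_sketch`, the in-memory `pool` list, and the status string values are all assumptions made for illustration (the real constants are defined on `TaskManager`, and the real function fetches tasks through `safe_fetch_task`).

```python
from typing import Any, Callable, Dict, List

# Assumed status values for illustration; the real constants live on TaskManager.
STATUS_WAITING = "waiting"
STATUS_PART_DONE = "part_done"
STATUS_DONE = "done"


def run_task_sketch(
    task_func: Callable[..., Any],
    pool: List[Dict[str, Any]],
    before_status: str = STATUS_WAITING,
    after_status: str = STATUS_DONE,
    **kwargs,
) -> bool:
    """Hypothetical in-memory analogue of run_task; `pool` is a plain list of task dicts."""
    if before_status not in (STATUS_WAITING, STATUS_PART_DONE):
        raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
    ever_run = False
    # Snapshot the matching tasks so the loop also ends when before_status == after_status.
    for task in [t for t in pool if t["status"] == before_status]:
        # WAITING tasks are run from their definition, PART_DONE tasks from their saved result.
        param = task["def"] if before_status == STATUS_WAITING else task["res"]
        task["res"] = task_func(param, **kwargs)
        task["status"] = after_status
        ever_run = True
    return ever_run


# Two-stage usage: WAITING -> PART_DONE, then PART_DONE -> DONE.
pool = [{"def": 1, "status": STATUS_WAITING}, {"def": 2, "status": STATUS_WAITING}]
run_task_sketch(lambda d: d * 10, pool, STATUS_WAITING, STATUS_PART_DONE)
run_task_sketch(lambda r: r + 1, pool, STATUS_PART_DONE, STATUS_DONE)
print([(t["res"], t["status"]) for t in pool])
```

The second call receives `task["res"]` from the first stage rather than `task["def"]`, which is the point of the `STATUS_PART_DONE` status.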

258
qlib/workflow/task/utils.py Normal file

@@ -0,0 +1,258 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Some tools for task management.
"""
import bisect
import pandas as pd
from qlib.data import D
from qlib.workflow import R
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
from pymongo.database import Database
from typing import Union
def get_mongodb() -> Database:
"""
Get the MongoDB database used for task management. The address and the name of the database must be declared first.
For example:
Using qlib.init():
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(..., mongo=mongo_conf)
After qlib.init():
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/",
"task_db_name" : "rolling_db"
}
Returns:
Database: the Database instance
"""
try:
cfg = C["mongo"]
except KeyError:
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
raise
client = MongoClient(cfg["task_url"])
return client.get_database(name=cfg["task_db_name"])
def list_recorders(experiment, rec_filter_func=None):
"""
List all recorders which can pass the filter in an experiment.
Args:
experiment (str or Experiment): the name of an Experiment or an instance
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
Returns:
dict: a dict {rid: recorder} after filtering.
"""
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
recs = experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
return recs_flt
class TimeAdjuster:
"""
Find appropriate date and adjust date.
"""
def __init__(self, future=True, end_time=None):
self._future = future
self.cals = D.calendar(future=future, end_time=end_time)
def set_end_time(self, end_time=None):
"""
Set end time. If None, the calendar's end time is used.
Args:
end_time
"""
self.cals = D.calendar(future=self._future, end_time=end_time)
def get(self, idx: int):
"""
Get datetime by index.
Parameters
----------
idx : int
index of the calendar
"""
if idx >= len(self.cals):
return None
return self.cals[idx]
def max(self) -> pd.Timestamp:
"""
Return the max calendar datetime
"""
return max(self.cals)
def align_idx(self, time_point, tp_type="start") -> int:
"""
Align the index of time_point in the calendar.
Parameters
----------
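time_point
the time point to align
tp_type : str
time point type (`"start"`, `"end"`); `"start"` aligns to the first calendar date not earlier than time_point, `"end"` aligns to the last calendar date not later than time_point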
Returns
-------
index : int
"""
time_point = pd.Timestamp(time_point)
if tp_type == "start":
idx = bisect.bisect_left(self.cals, time_point)
elif tp_type == "end":
idx = bisect.bisect_right(self.cals, time_point) - 1
else:
raise NotImplementedError("This type of input is not supported")
return idx
def cal_interval(self, time_point_A, time_point_B) -> int:
"""
Calculate the trading day interval (time_point_A - time_point_B)
Args:
time_point_A : time_point_A
time_point_B : time_point_B (is the past of time_point_A)
Returns:
int: the interval between A and B
"""
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
"""
Align time_point to trade date of calendar
Args:
time_point
Time point
tp_type : str
time point type (`"start"`, `"end"`)
Returns:
pd.Timestamp
"""
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
"""
Align the given date to the trade date
for example:
.. code-block:: python
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
Parameters
----------
segment
Returns
-------
Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
"""
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
elif isinstance(segment, (tuple, list)):
return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
else:
raise NotImplementedError("This type of input is not supported")
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
"""
Truncate the segment based on the test_start date
Parameters
----------
segment : tuple
time segment
test_start
days : int
The number of trading days to truncate: each time point is moved to at most `days` trading days before test_start,
because the data in this segment may need `days` days of data
Returns
---------
tuple: new segment
"""
test_idx = self.align_idx(test_start)
if isinstance(segment, tuple):
new_seg = []
for time_point in segment:
tp_idx = min(self.align_idx(time_point), test_idx - days)
assert tp_idx > 0
new_seg.append(self.get(tp_idx))
return tuple(new_seg)
else:
raise NotImplementedError("This type of input is not supported")
SHIFT_SD = "sliding"
SHIFT_EX = "expanding"
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datetime of segment
Parameters
----------
seg :
datetime segment
step : int
rolling step
rtype : str
rolling type ("sliding" or "expanding")
Returns
--------
tuple: new segment
Raises
------
KeyError:
raised if the shifted start index is greater than len(self.cals); otherwise an index at or beyond the end of the calendar yields None for that time point
"""
if isinstance(seg, tuple):
start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
if rtype == self.SHIFT_SD:
start_idx += step
end_idx += step
elif rtype == self.SHIFT_EX:
end_idx += step
else:
raise NotImplementedError("This type of input is not supported")
if start_idx > len(self.cals):
raise KeyError("The segment is out of valid calendar")
return self.get(start_idx), self.get(end_idx)
else:
raise NotImplementedError("This type of input is not supported")
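
The alignment and shifting logic of `TimeAdjuster` can be tried on a toy calendar without initializing qlib. This is a sketch under stated assumptions: `CALS` is a hypothetical calendar of ten business days built with `pd.bdate_range` instead of `D.calendar`, and `align_idx` and `shift` here are free functions that mirror the methods above rather than the class itself (the out-of-calendar `KeyError` check is omitted).

```python
import bisect

import pandas as pd

# Hypothetical toy calendar: Mon 2021-01-04 .. Fri 2021-01-15, weekends excluded.
CALS = list(pd.bdate_range("2021-01-04", "2021-01-15"))


def align_idx(cals, time_point, tp_type="start"):
    """Index of the first date >= time_point ("start") or the last date <= time_point ("end")."""
    time_point = pd.Timestamp(time_point)
    if tp_type == "start":
        return bisect.bisect_left(cals, time_point)
    if tp_type == "end":
        return bisect.bisect_right(cals, time_point) - 1
    raise NotImplementedError("This type of input is not supported")


def shift(cals, seg, step, rtype="sliding"):
    """Move both ends ("sliding") or only the end ("expanding") of seg by `step` trading days."""
    start_idx = align_idx(cals, seg[0], "start")
    end_idx = align_idx(cals, seg[1], "end")
    if rtype == "sliding":
        start_idx += step
        end_idx += step
    elif rtype == "expanding":
        end_idx += step
    else:
        raise NotImplementedError("This type of input is not supported")

    def get(idx):
        # Like TimeAdjuster.get: None when the index runs past the calendar.
        return cals[idx] if idx < len(cals) else None

    return get(start_idx), get(end_idx)


# Saturday 2021-01-09 aligns forward as a start and backward as an end.
print(CALS[align_idx(CALS, "2021-01-09", "start")], CALS[align_idx(CALS, "2021-01-09", "end")])
print(shift(CALS, ("2021-01-04", "2021-01-08"), 2, "sliding"))
print(shift(CALS, ("2021-01-04", "2021-01-08"), 2, "expanding"))
```

With `"sliding"` the window keeps its length and moves forward, while `"expanding"` keeps the start fixed and only extends the end.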


@@ -55,7 +55,9 @@ REQUIRED = [
"tornado",
"joblib>=0.17.0",
"ruamel.yaml>=0.16.12",
"pymongo==3.7.2", # For task management
"scikit-learn>=0.22",
"dill",
]
# Numpy include