1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00

Merge branch 'online_srv' into online_srv_blin

This commit is contained in:
you-n-g
2021-05-07 21:07:27 +08:00
committed by GitHub
55 changed files with 2041 additions and 1022 deletions

View File

@@ -1,4 +1,4 @@
.. _task_managment:
.. _task_management:
=================================
Task Management
@@ -10,15 +10,17 @@ Introduction
=============
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Running`_ and `Task Collecting`_.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling.py>`_.
This whole process can be used in `Online Serving <../component/online.html>`_.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
Task Generating
===============
A ``task`` consists of `Model`, `Dataset`, `Record` or anything added by users.
The specific task template(/definition/config) can be viewed in
The specific task template can be viewed in
`Task Section <../component/workflow.html#task-section>`_.
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
@@ -27,15 +29,16 @@ Here is the base class of ``TaskGen``:
.. autoclass:: qlib.workflow.task.gen.TaskGen
:members:
``Qlib`` provider a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
This class allows users to verify the effect of data from different periods on the model in one experiment.
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
This class allows users to verify the effect of data from different periods on the model in one experiment. More information in `here <../reference/api.html#TaskGen>`_.
Task Storing
===============
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
Users need to provide the URL and database name of ``task`` storing like this.
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make statement like this.
.. code-block:: python
@@ -45,13 +48,12 @@ Users need to provide the URL and database name of ``task`` storing like this.
"task_db_name" : "rolling_db" # database name
}
The CRUD methods of ``task`` can be found in TaskManager.
More methods can be seen in the `Github <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/manage.py>`_.
.. autoclass:: qlib.workflow.task.manage.TaskManager
:members:
Task Running
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
Task Training
===============
After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
@@ -60,14 +62,24 @@ It will run the whole workflow defined by ``task``, which includes *Model*, *Dat
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
``Trainer`` will train a list of tasks and return a list of model recorder.
``Qlib`` offer two kind of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
More information is in `here <../reference/api.html#Trainer>`_.
Task Collecting
===============
To see the results of ``task`` after running or to update something, ``Qlib`` provides a ``TaskCollector`` to collect the tasks by filter condition (optional).
Here are some methods in this class.
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
.. autoclass:: qlib.workflow.task.collect.TaskCollector
:members:
`Collector <../reference/api.html#Collector>`_ can collect object from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
``Qlib`` provides a concrete `example <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling_with_updating.py>`_, including a whole process of `Task Generating`_ (using `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_), `Task Storing`_, `Task Running`_ and `Task Collecting`_.
Besides, the `example <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling_with_updating.py>`_ uses a ``ModelUpdater`` inherited from ``TaskCollector``, which can update the inferences and retrain the model if it is out of date.
Actually, the model updating can be viewed as a subset of ``Online Serving``.
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
So the hierarchy is ``Collector``'s second step correspond to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_

View File

@@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
.. note::
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
Data API
========================

41
docs/component/online.rst Normal file
View File

@@ -0,0 +1,41 @@
.. _online:
=================================
Online Serving
=================================
.. currentmodule:: qlib
Introduction
=============
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
``Online Serving`` is a set of module for online models using latest data,
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
If you have many models or `task` need to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ maybe based on `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
=============
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
=============
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
=============
.. automodule:: qlib.workflow.online.utils
:members:
Updater
=============
.. automodule:: qlib.workflow.online.update
:members:

View File

@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>
.. toctree::
:maxdepth: 3

View File

@@ -154,36 +154,71 @@ Record Template
.. automodule:: qlib.workflow.record_temp
:members:
Task Management
====================
RollingGen
TaskGen
--------------------
.. autoclass:: qlib.workflow.task.gen.RollingGen
.. automodule:: qlib.workflow.task.gen
:members:
TaskManager
--------------------
.. autoclass:: qlib.workflow.task.manage.TaskManager
.. automodule:: qlib.workflow.task.manage
:members:
TaskCollector
Trainer
--------------------
.. autoclass:: qlib.workflow.task.collect.TaskCollector
.. automodule:: qlib.model.trainer
:members:
ModelUpdater
Collector
--------------------
.. autoclass:: qlib.workflow.task.update.ModelUpdater
.. automodule:: qlib.workflow.task.collect
:members:
TimeAdjuster
Group
--------------------
.. autoclass:: qlib.workflow.task.utils.TimeAdjuster
.. automodule:: qlib.model.ens.group
:members:
Ensemble
--------------------
.. automodule:: qlib.model.ens.ensemble
:members:
Utils
--------------------
.. automodule:: qlib.workflow.task.utils
:members:
Online Serving
====================
Online Manager
--------------------
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
--------------------
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
--------------------
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
:members:
Utils
====================

View File

@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
return -1, -1
def get_column_definition(self):
""""Returns formatted column definition in order expected by the TFT."""
"""Returns formatted column definition in order expected by the TFT."""
column_definition = self._column_definition

View File

@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
Run the example by running the following command:
```bash
python workflow.py dump_and_load_dataset
```
```
## Benchmarks Performance
### Signal Test
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|---|---|---|---|---|---|---|---|---|---|
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |

View File

@@ -0,0 +1,65 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
region: cn
market: &market 'csi300'
start_time: &start_time "2020-09-15 00:00:00"
end_time: &end_time "2021-01-18 16:00:00"
train_end_time: &train_end_time "2020-11-15 16:00:00"
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
test_start_time: &test_start_time "2020-12-01 00:00:00"
data_handler_config: &data_handler_config
start_time: *start_time
end_time: *end_time
fit_start_time: *start_time
fit_end_time: *train_end_time
instruments: *market
freq: '1min'
infer_processors:
- class: 'RobustZScoreNorm'
kwargs:
fields_group: 'feature'
clip_outlier: false
- class: "Fillna"
kwargs:
fields_group: 'feature'
learn_processors:
- class: 'DropnaLabel'
- class: 'CSRankNorm'
kwargs:
fields_group: 'label'
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
task:
model:
class: "HFLGBModel"
module_path: "qlib.contrib.model.highfreq_gdbt_model"
kwargs:
objective: 'binary'
metric: ['binary_logloss','auc']
verbosity: -1
learning_rate: 0.01
max_depth: 8
num_leaves: 150
lambda_l1: 1.5
lambda_l2: 1
num_threads: 20
dataset:
class: "DatasetH"
module_path: "qlib.data.dataset"
kwargs:
handler:
class: "Alpha158"
module_path: "qlib.contrib.data.handler"
kwargs: *data_handler_config
segments:
train: [*start_time, *train_end_time]
valid: [*train_end_time, *valid_end_time]
test: [*test_start_time, *end_time]
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}

View File

@@ -1,24 +1,23 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how a TrainerRM work based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be showed in task_collecting.
"""
from pprint import pprint
import time
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import TrainerR, task_train
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
import pandas as pd
from qlib.workflow.task.utils import list_recorders
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
"""
This example shows how a Trainer work based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be showed in task_collecting.
"""
data_handler_config = {
"start_time": "2008-01-01",
@@ -139,11 +138,13 @@ class RollingTaskExample:
return True
return False
artifact = ens_workflow(
RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
RollingGroup(),
collector = RecorderCollector(
experiment=self.experiment_name,
process_list=RollingGroup(),
rec_key_func=rec_key,
rec_filter_func=my_filter,
)
print(artifact)
print(collector())
def main(self):
self.reset()

View File

@@ -1,21 +1,18 @@
import fire
import qlib
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.online.simulator import OnlineSimulator
from qlib.workflow.task.collect import RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.utils import list_recorders
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This examples is about the OnlineManager and OnlineSimulator based on rolling tasks.
The OnlineManager will focus on the updating of your online models.
The OnlineSimulator will focus on the simulating real updating routine of your online models.
This examples is about how can simulate the OnlineManager based on rolling tasks.
"""
import fire
import qlib
from qlib.model.trainer import DelayTrainerRM
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingAverageStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
data_handler_config = {
"start_time": "2018-01-01",
@@ -86,10 +83,10 @@ class OnlineSimulationExample:
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=[task_xgboost_config], # , task_lgb_config]
tasks=[task_xgboost_config, task_lgb_config],
):
"""
init OnlineManagerExample.
Init OnlineManagerExample.
Args:
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
@@ -105,6 +102,8 @@ class OnlineSimulationExample:
"""
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
self.end_time = end_time
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
@@ -115,62 +114,30 @@ class OnlineSimulationExample:
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=self.rolling_gen,
trainer=self.trainer,
self.rolling_online_manager = OnlineManager(
RollingAverageStrategy(
exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
),
begin_time=self.start_time,
need_log=False,
) # The OnlineManager based on Rolling
self.onlinesimulator = OnlineSimulator(
start_time=start_time,
end_time=end_time,
online_manager=self.rolling_online_manager,
)
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
self.task_manager.remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
exp.delete_recorder(rid)
# Run this firstly to see the workflow in OnlineManager
def first_train(self):
print("========== first train ==========")
self.reset()
self.rolling_online_manager.first_train(self.tasks)
# Run this secondly to see the simulating in OnlineSimulator
def simulate(self):
print("========== simulate ==========")
self.onlinesimulator.simulate()
print(self.rolling_online_manager.collect_artifact())
print("========== online models ==========")
recs_dict = self.onlinesimulator.online_models()
for time, recs in recs_dict.items():
print(f"{str(time[0])} to {str(time[1])}:")
for rec in recs:
print(rec.info["id"])
print("========== online signals ==========")
print(self.rolling_online_manager.get_signals())
# Run this to run all workflow automaticly
# Run this to run all workflow automatically
def main(self):
self.first_train()
self.simulate()
print("========== reset ==========")
self.rolling_online_manager.reset()
print("========== simulate ==========")
self.rolling_online_manager.simulate(end_time=self.end_time)
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
print("========== online history ==========")
print(self.rolling_online_manager.get_online_history(self.exp_name))
if __name__ == "__main__":
## to run all workflow automaticly with your own parameters, use the command below
## to run all workflow automatically with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineSimulationExample)

View File

@@ -1,22 +1,26 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example show how OnlineManager works with rolling tasks.
There are two parts including first train and routine.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
"""
import os
from pathlib import Path
import pickle
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingAverageStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.online.manager import RollingOnlineManager
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.task.utils import list_recorders
from qlib.model.trainer import TrainerRM
"""
This example show how RollingOnlineManager works with rolling tasks.
There are two parts including first train and routine.
Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models.
Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
"""
data_handler_config = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
@@ -77,58 +81,75 @@ task_xgboost_config = {
class RollingOnlineExample:
def __init__(
self,
exp_name="rolling_exp",
task_pool="rolling_task",
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=[task_xgboost_config], # , task_lgb_config],
):
self.exp_name = exp_name
self.task_pool = task_pool
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_online_manager = RollingOnlineManager(
experiment_name=exp_name,
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
trainer=TrainerRM(self.exp_name, self.task_pool),
)
self.tasks = tasks
self.rolling_step = rolling_step
strategy = []
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategy.append(
RollingAverageStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
TrainerRM(experiment_name=name_id, task_pool=name_id),
)
)
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
self.rolling_online_manager = OnlineManager(strategy)
self.collector = self.rolling_online_manager.get_collector()
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
for task in self.tasks:
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
TaskManager(name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
for rid in list_recorders(
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
):
exp.delete_recorder(rid)
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False):
exp.delete_recorder(rid)
def first_run(self):
print("========== reset ==========")
self.rolling_online_manager.reset()
print("========== first_run ==========")
self.reset()
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
self.rolling_online_manager.first_train()
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
print(self.rolling_online_manager.collect_artifact())
print("========== collect results ==========")
print(self.collector())
def routine(self):
print("========== routine ==========")
print("========== load ==========")
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
self.rolling_online_manager = pickle.load(f)
print("========== routine ==========")
self.rolling_online_manager.routine()
print(self.rolling_online_manager.collect_artifact())
print("========== collect results ==========")
print(self.collector())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
def main(self):
self.first_run()
@@ -137,11 +158,11 @@ class RollingOnlineExample:
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python task_manager_rolling_with_updating.py first_run
# python rolling_online_management.py first_run
####### to update the models and predictions after the trading time, use the command below
# python task_manager_rolling_with_updating.py after_day
# python rolling_online_management.py after_day
####### to define your own parameters, use `--`
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -1,16 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example show how OnlineTool works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained model to `online` model.
Next, we will finish updating online prediction.
"""
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.manager import OnlineManagerR
from qlib.workflow.task.utils import list_recorders
"""
This example show how OnlineManager works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model.
Next, the RollingOnlineManager will finish updating online prediction
"""
from qlib.workflow.online.utils import OnlineToolR
data_handler_config = {
"start_time": "2008-01-01",
@@ -65,15 +66,15 @@ class UpdatePredExample:
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_manager = OnlineManagerR(self.experiment_name)
self.online_tool = OnlineToolR(self.experiment_name)
self.task_config = task_config
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_manager.reset_online_tag(rec) # set to online model
self.online_tool.reset_online_tag(rec) # set to online model
def update_online_pred(self):
self.online_manager.update_online_pred()
self.online_tool.update_online_pred()
def main(self):
self.first_train()

View File

@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""Parameters
"""
Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
"""
Update the account and strategy
Parameters
----------
trade_account : Account()

View File

@@ -128,7 +128,7 @@ class Position:
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
"""generate stock amount dict {stock_id : amount of stock}"""
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:

View File

@@ -8,6 +8,59 @@ import pandas as pd
from typing import Tuple
def calc_long_short_prec(
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
"""
calculate the precision for long and short operation
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
.. code-block:: python
score
datetime instrument
2020-12-01 09:30:00 SH600068 0.553634
SH600195 0.550017
SH600276 0.540321
SH600584 0.517297
SH600715 0.544674
label :
label
date_col :
date_col
Returns
-------
(pd.Series, pd.Series)
long precision and short precision in time level
"""
if is_alpha:
label = label - label.mean(level=date_col)
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
raise ValueError("Need more instruments to calculate precision")
df = pd.DataFrame({"pred": pred, "label": label})
if dropna:
df.dropna(inplace=True)
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
# find the top/low quantile of prediction and treat them as long and short target
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
groupll = long.groupby(date_col)
l_dom = groupll.apply(lambda x: x > 0)
l_c = groupll.count()
groups = short.groupby(date_col)
s_dom = groups.apply(lambda x: x < 0)
s_c = groups.count()
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
"""calc_ic.

View File

@@ -0,0 +1,157 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import lightgbm as lgb
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
import warnings
class HFLGBModel(ModelFT):
"""LightGBM Model for high frequency prediction"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self.params = {"objective": loss, "verbosity": -1}
self.params.update(kwargs)
self.model = None
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
"""
Calcaute the signal metrics by daily level
"""
up_pre, down_pre = [], []
up_alpha_ll, down_alpha_ll = [], []
for date in y_test.index.get_level_values(0).unique():
df_res = y_test.loc[date].sort_values("pred")
if int(l_cut * len(df_res)) < 10:
warnings.warn("Warning: threhold is too low or instruments number is not enough")
continue
top = df_res.iloc[: int(l_cut * len(df_res))]
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
down_alpha = top[top.columns[0]].mean()
up_alpha = bottom[bottom.columns[0]].mean()
up_pre.append(up_precision)
down_pre.append(down_precision)
up_alpha_ll.append(up_alpha)
down_alpha_ll.append(down_alpha)
return (
np.array(up_pre).mean(),
np.array(down_pre).mean(),
np.array(up_alpha_ll).mean(),
np.array(down_alpha_ll).mean(),
)
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
"""
Test the sigal in high frequency test set
"""
if self.model == None:
raise ValueError("Model hasn't been trained yet")
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
df_test.dropna(inplace=True)
x_test, y_test = df_test["feature"], df_test["label"]
# Convert label into alpha
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
y_test["pred"] = res
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
print("===============================")
print("High frequency signal test")
print("===============================")
print("Test set precision: ")
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
print("Test Alpha Average in test set: ")
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_train["feature"], df_valid["label"]
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
l_name = df_train["label"].columns[0]
# Convert label into alpha
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
mapping_fn = lambda x: 0 if x < 0 else 1
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
x_train, y_train = df_train["feature"], df_train["label_c"].values
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
return dtrain, dvalid
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
dtrain, dvalid = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
num_boost_round : int
number of round to finetune model
verbose_eval : int
verbose level
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
)

View File

@@ -214,7 +214,7 @@ def cumulative_return_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
Graph desc:

View File

@@ -94,7 +94,7 @@ def rank_label_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.

View File

@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
qcr.report_graph(report_normal_df)
qcr.analysis_position.report_graph(report_normal_df)
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.

View File

@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
class BaseGraph:
""""""
""" """
_name = None

View File

@@ -1,10 +1,11 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import logging
import pandas as pd
import numpy as np
from sklearn.metrics import mean_squared_error
from typing import Dict, Text, Any
import numpy as np
from ...contrib.eva.alpha import calc_ic
from ...workflow.record_temp import RecordTemp
@@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord
from ...data import dataset as qlib_dataset
from ...log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class MultiSegRecord(RecordTemp):

View File

@@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider):
# if future calendar not exists, return current calendar
if not os.path.exists(fname):
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
get_module_logger("data").warning(
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
)
fname = self._uri_cal.format(freq)
else:
fname = self._uri_cal.format(freq)
@@ -1016,7 +1019,8 @@ class ClientProvider(BaseProvider):
self.logger = get_module_logger(self.__class__.__name__)
if isinstance(Cal, ClientCalendarProvider):
Cal.set_conn(self.client)
Inst.set_conn(self.client)
if isinstance(Inst, ClientInstrumentProvider):
Inst.set_conn(self.client)
if hasattr(DatasetD, "provider"):
DatasetD.provider.set_conn(self.client)
else:

View File

@@ -27,7 +27,7 @@ class Dataset(Serializable):
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
The data could specify the info to caculate the essential data for preparation
The data could specify the info to calculate the essential data for preparation
"""
self.setup_data(**kwargs)
super().__init__()
@@ -92,7 +92,7 @@ class DatasetH(Dataset):
handler : Union[dict, DataHandler]
handler could be:
- insntance of `DataHandler`
- instance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
@@ -124,7 +124,7 @@ class DatasetH(Dataset):
Parameters
----------
handler_kwargs : dict
Config of DataHanlder, which could include the following arguments:
Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
@@ -148,11 +148,11 @@ class DatasetH(Dataset):
Parameters
----------
handler_kwargs : dict
init arguments of DataHanlder, which could include the following arguments:
init arguments of DataHandler, which could include the following arguments:
- init_type : Init Type of Handler
- enable_cache : wheter to enable cache
- enable_cache : whether to enable cache
"""
super().setup_data(**kwargs)
@@ -238,7 +238,7 @@ class TSDataSampler:
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
@@ -310,7 +310,7 @@ class TSDataSampler:
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@staticmethod
@@ -472,7 +472,7 @@ class TSDatasetH(DatasetH):
(T)ime-(S)eries Dataset (H)andler
Covnert the tabular data to Time-Series data
Convert the tabular data to Time-Series data
Requirements analysis

View File

@@ -36,7 +36,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
Any order of the index level can be suported (The order will be implied in the data).
Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -77,7 +77,7 @@ class DataHandler(Serializable):
data_loader : Tuple[dict, str, DataLoader]
data loader to load the data.
init_data :
intialize the original data in the constructor.
initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
"""
@@ -128,7 +128,7 @@ class DataHandler(Serializable):
def setup_data(self, enable_cache: bool = False):
"""
Set Up the data in case of running intialization for multiple time
Set Up the data in case of running initialization for multiple time
It is responsible for maintaining following variable
1) self._data
@@ -453,7 +453,7 @@ class DataHandlerLP(DataHandler):
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
"""
Set up the data in case of running intialization for multiple time
Set up the data in case of running initialization for multiple time
Parameters
----------

View File

@@ -130,7 +130,7 @@ class FilterCol(Processor):
class TanhProcess(Processor):
""" Use tanh to process noise data"""
"""Use tanh to process noise data"""
def __call__(self, df):
def tanh_denoise(data):
@@ -145,7 +145,7 @@ class TanhProcess(Processor):
class ProcessInf(Processor):
"""Process infinity """
"""Process infinity"""
def __call__(self, df):
def replace_inf(data):

View File

@@ -12,7 +12,41 @@ from contextlib import contextmanager
from .config import C
def get_module_logger(module_name, level: Optional[int] = None):
class MetaLogger(type):
def __new__(cls, name, bases, dict):
wrapper_dict = logging.Logger.__dict__.copy()
for key in wrapper_dict:
if key not in dict and key != "__reduce__":
dict[key] = wrapper_dict[key]
return type.__new__(cls, name, bases, dict)
class QlibLogger(metaclass=MetaLogger):
"""
Customized logger for Qlib.
"""
def __init__(self, module_name):
self.module_name = module_name
self.level = 0
@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
return logger
def setLevel(self, level):
self.level = level
def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
if name in {"__setstate__"}:
raise AttributeError
return self.logger.__getattribute__(name)
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
"""
Get a logger for a specific module.
@@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
module_name = "qlib.{}".format(module_name)
# Get logger.
module_logger = logging.getLogger(module_name)
module_logger = QlibLogger(module_name)
module_logger.setLevel(level)
return module_logger

View File

@@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def predict(self, *args, **kwargs) -> object:
""" Make predictions after modeling things """
"""Make predictions after modeling things"""
pass
def __call__(self, *args, **kwargs) -> object:
""" leverage Python syntactic sugar to make the models' behaviors like functions """
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
return self.predict(*args, **kwargs)

View File

@@ -1,36 +1,12 @@
from abc import abstractmethod
from typing import Callable, Union
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions.
"""
from typing import Union
import pandas as pd
from qlib.workflow.task.collect import Collector
from qlib.utils.serial import Serializable
def ens_workflow(collector: Collector, process_list, *args, **kwargs):
"""the ensemble workflow based on collector and different dict processors.
Args:
collector (Collector): the collector to collect the result into {result_key: things}
process_list (list or Callable): the list of processors or the instance of processor to process dict.
The processor order is same as the list order.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
Returns:
dict: the ensemble dict
"""
collect_dict = collector.collect()
if not isinstance(process_list, list):
process_list = [process_list]
ensemble = {}
for artifact in collect_dict:
value = collect_dict[artifact]
for process in process_list:
if not callable(process):
raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.")
value = process(value, *args, **kwargs)
ensemble[artifact] = value
return ensemble
class Ensemble:
@@ -49,21 +25,45 @@ class Ensemble:
raise NotImplementedError(f"Please implement the `__call__` method.")
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in dict. Make result more readable.
{Only key: Only value} -> Only value
If there are more than 1 key or less than 1 key, then do nothing.
Even you can run this recursively to make dict more readable.
NOTE: Default run recursively.
"""
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
if not isinstance(ensemble_dict, dict):
return ensemble_dict
if recursion:
tmp_dict = {}
for k, v in ensemble_dict.items():
tmp_dict[k] = self(v, recursion)
ensemble_dict = tmp_dict
keys = list(ensemble_dict.keys())
if len(keys) == 1:
ensemble_dict = ensemble_dict[keys[0]]
return ensemble_dict
class RollingEnsemble(Ensemble):
"""Merge the rolling objects in an Ensemble"""
def __call__(self, ensemble_dict: dict):
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.Dataframe, and have the index "datetime"
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
Args:
ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.Dataframe: the complete result of rolling.
pd.DataFrame: the complete result of rolling.
"""
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
@@ -72,3 +72,24 @@ class RollingEnsemble(Ensemble):
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact
class AverageEnsemble(Ensemble):
def __call__(self, ensemble_dict: dict):
"""
Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.DataFrame: the complete result of averaging.
"""
values = list(ensemble_dict.values())
results = pd.concat(values, axis=1)
results = results.mean(axis=1).to_frame("score")
results = results.sort_index()
return results

View File

@@ -1,3 +1,17 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Group can group a set of object based on `group_func` and change them to a dict.
After group, we provide a method to reduce them.
For example:
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
"""
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from joblib import Parallel, delayed
@@ -21,20 +35,20 @@ class Group:
self._group_func = group_func
self._ens_func = ens
def group(self, *args, **kwargs):
def group(self, *args, **kwargs) -> dict:
# TODO: such design is weird when `_group_func` is the only configurable part in the class
if isinstance(getattr(self, "_group_func", None), Callable):
return self._group_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `group_func`.")
def reduce(self, *args, **kwargs):
def reduce(self, *args, **kwargs) -> dict:
if isinstance(getattr(self, "_ens_func", None), Callable):
return self._ens_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `_ens_func`.")
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs):
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict:
"""Group the ungrouped_dict into different groups.
Args:
@@ -59,7 +73,7 @@ class Group:
class RollingGroup(Group):
"""group the rolling dict"""
def group(self, rolling_dict: dict):
def group(self, rolling_dict: dict) -> dict:
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly.

View File

@@ -1,27 +0,0 @@
import abc
import typing
class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
"""
generate
Parameters
----------
args, kwargs:
The info for generating tasks
Example 1):
input: a specific task template
output: rolling version of the tasks
Example 2):
input: a specific task template
output: a set of tasks with different losses
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass

View File

@@ -1,58 +1,69 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
"""
The Trainer will train a list of tasks and return a list of model recorder.
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
"""
import socket
import time
from xxlimited import Str
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
from qlib.workflow import R
from qlib.workflow.recorder import Recorder
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.task.manage import TaskManager, run_task
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.model.base import Model
import socket
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder:
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
Begin a task training with starting a recorder and saving the task config.
Begin a task training to start a recorder and save the task config.
Args:
task_config (dict)
experiment_name (str)
task_config (dict): the config of a task
experiment_name (str): the name of experiment
recorder_name (str): the given name will be the recorder name. None for using rid.
Returns:
Recorder
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=str(time.time())):
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname(), "train_status": "begin_task_train"})
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
Finished task training with real model fitting and saving.
Finish task training with real model fitting and saving.
Args:
rec (Recorder): This recorder will be resumed
experiment_name (str)
rec (Recorder): the recorder will be resumed
experiment_name (str): the name of experiment
Returns:
Recorder
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True):
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiaiton
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# model training
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# This dataset is saved for online inference. So the concrete data should not be dumped
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
@@ -67,18 +78,18 @@ def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
rconf = {"recorder": rec}
r = cls(**kwargs, **rconf)
r.generate()
R.set_tags(**{"train_status": "end_task_train"})
return rec
def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
task based training
Task based training, will be divided into two steps.
Parameters
----------
task_config : dict
A dict describes a task setting.
The config of a task.
experiment_name: str
The name of experiment
@@ -96,39 +107,79 @@ class Trainer:
The trainer which can train a list of model
"""
def train(self, tasks: list, *args, **kwargs):
"""Given a list of model definition, begin a training and return the models.
def __init__(self):
self.delay = False
def train(self, tasks: list, *args, **kwargs) -> list:
"""
Given a list of model definition, begin a training and return the models.
Args:
tasks: a list of tasks
Returns:
list: a list of models
"""
raise NotImplementedError(f"Please implement the `train` method.")
def end_train(self, models, *args, **kwargs):
"""Given a list of models, finished something in the end of training if you need.
def end_train(self, models: list, *args, **kwargs) -> list:
"""
Given a list of models, finished something in the end of training if you need.
The models maybe Recorder, txt file, database and so on.
Args:
models: a list of models
Returns:
list: a list of models
"""
# do nothing if you finished all work in `train` method
return models
def is_delay(self) -> bool:
"""
If Trainer will delay finishing `end_train`.
Returns:
bool: if DelayTrainer
"""
return self.delay
def reset(self):
"""
Reset the Trainer status.
"""
pass
class TrainerR(Trainer):
"""Trainer based on (R)ecorder.
"""
Trainer based on (R)ecorder.
It will train a list of tasks and return a list of model recorder in a linear way.
Assumption: models were defined by `task` and the results will saved to `Recorder`
"""
def __init__(self, experiment_name, train_func=task_train):
def __init__(self, experiment_name: str, train_func: Callable = task_train):
"""
Init TrainerR.
Args:
experiment_name (str): the name of experiment.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.train_func = train_func
def train(self, tasks: list, train_func=None, *args, **kwargs):
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]:
"""
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
kwargs: the params for train_func.
Returns:
list: a list of Recorders
@@ -137,17 +188,74 @@ class TrainerR(Trainer):
train_func = self.train_func
recs = []
for task in tasks:
recs.append(train_func(task, self.experiment_name, *args, **kwargs))
rec = train_func(task, self.experiment_name, **kwargs)
rec.set_tags(**{"train_status": "begin_task_train"})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> list:
for rec in recs:
rec.set_tags(**{"train_status": "end_task_train"})
return recs
class DelayTrainerR(TrainerR):
"""
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
"""
Init TrainerRM.
Args:
experiment_name (str): the name of experiment.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, train_func)
self.end_train_func = end_train_func
self.delay = True
def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
Args:
recs (list): a list of Recorder, the tasks have been saved to them
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
kwargs: the params for end_train_func.
Returns:
list: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
for rec in recs:
end_train_func(rec, **kwargs)
rec.set_tags(**{"train_status": "end_task_train"})
return recs
class TrainerRM(Trainer):
"""Trainer based on (R)ecorder and Task(M)anager
"""
Trainer based on (R)ecorder and Task(M)anager.
It can train a list of tasks and return a list of model recorder in a multiprocessing way.
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
"""
def __init__(self, experiment_name: str, task_pool: str, train_func=task_train):
"""
Init TrainerR.
Args:
experiment_name (str): the name of experiment.
task_pool (str): task pool name in TaskManager.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
@@ -155,20 +263,23 @@ class TrainerRM(Trainer):
def train(
self,
tasks: list,
train_func=None,
before_status=TaskManager.STATUS_WAITING,
after_status=TaskManager.STATUS_DONE,
*args,
train_func: Callable = None,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
):
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
) -> List[Recorder]:
"""
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
This method defaults to a single process, but TaskManager offered a great way to parallel training.
Users can customize their train_func to realize multiple processes or even multiple machines.
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs: the params for train_func.
Returns:
list: a list of Recorders
@@ -183,63 +294,29 @@ class TrainerRM(Trainer):
experiment_name=self.experiment_name,
before_status=before_status,
after_status=after_status,
*args,
**kwargs,
)
recs = []
for _id in _id_list:
recs.append(tm.re_query(_id)["res"])
rec = tm.re_query(_id)["res"]
rec.set_tags(**{"train_status": "begin_task_train"})
recs.append(rec)
return recs
class DelayTrainerR(TrainerR):
"""
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
super().__init__(experiment_name, train_func)
self.end_train_func = end_train_func
self.recs = []
def train(self, tasks: list, train_func, *args, **kwargs):
"""
Same as `train` of TrainerR, the results will be recorded in self.recs
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
Returns:
list: a list of Recorders
"""
self.recs = super().train(tasks, train_func=train_func, *args, **kwargs)
return self.recs
def end_train(self, recs=None, end_train_func=None):
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finished real data loading and model fitting.
Args:
recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.
end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.
Returns:
list: a list of Recorders
"""
if recs is None:
recs = copy.deepcopy(self.recs)
# the models will be only trained once
self.recs = []
if end_train_func is None:
end_train_func = self.end_train_func
def end_train(self, recs: list, **kwargs) -> list:
for rec in recs:
end_train_func(rec)
rec.set_tags(**{"train_status": "end_task_train"})
return recs
def reset(self):
"""
.. note::
this method will delete all task in this task_pool!
"""
tm = TaskManager(task_pool=self.task_pool)
tm.remove()
class DelayTrainerRM(TrainerRM):
"""
@@ -250,28 +327,28 @@ class DelayTrainerRM(TrainerRM):
def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train):
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
def train(self, tasks: list, train_func=None, *args, **kwargs):
def train(self, tasks: list, train_func=None, **kwargs):
"""
Same as `train` of TrainerRM, the results will be recorded in self.recs
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
Returns:
list: a list of Recorders
"""
return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, *args, **kwargs)
return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs)
def end_train(self, recs, end_train_func=None):
def end_train(self, recs, end_train_func=None, **kwargs):
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finished real data loading and model fitting.
This class will finish real data loading and model fitting.
Args:
recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs..
end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
kwargs: the params for end_train_func.
Returns:
list: a list of Recorders
@@ -284,5 +361,8 @@ class DelayTrainerRM(TrainerRM):
self.task_pool,
experiment_name=self.experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
for rec in recs:
rec.set_tags(**{"train_status": "end_task_train"})
return recs

View File

@@ -5,9 +5,9 @@ import abc
class BaseOptimizer(abc.ABC):
""" Construct portfolio with a optimization related method """
"""Construct portfolio with a optimization related method"""
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> object:
""" Generate a optimized portfolio allocation """
"""Generate a optimized portfolio allocation"""
pass

View File

@@ -3,11 +3,12 @@
from pathlib import Path
import pickle
from typing import Union
class Serializable:
"""
Serializable will change the behaviours of pickle.
Serializable will change the behaviors of pickle.
- It only saves the state whose name **does not** start with `_`
It provides a syntactic sugar for distinguish the attributes which user doesn't want.
- For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk
@@ -70,7 +71,7 @@ class Serializable:
obj.config(**params, recursive=True)
del self.__dict__[self.FLAG_KEY]
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
pickle.dump(self, f)

View File

@@ -23,7 +23,10 @@ class QlibRecorder:
@contextmanager
def start(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -45,8 +48,12 @@ class QlibRecorder:
Parameters
----------
experiment_id : str
id of the experiment one wants to start.
experiment_name : str
name of the experiment one wants to start.
recorder_id : str
id of the recorder under the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
@@ -57,7 +64,14 @@ class QlibRecorder:
resume : bool
whether to resume the specific recorder with given name under the given experiment.
"""
run = self.start_exp(experiment_name, recorder_name, uri, resume)
run = self.start_exp(
experiment_id=experiment_id,
experiment_name=experiment_name,
recorder_id=recorder_id,
recorder_name=recorder_name,
uri=uri,
resume=resume,
)
try:
yield run
except Exception as e:
@@ -65,7 +79,9 @@ class QlibRecorder:
raise e
self.end_exp(Recorder.STATUS_FI)
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
def start_exp(
self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False
):
"""
Lower level method for starting an experiment. When use this method, one should end the experiment manually
and the status of the recorder may not be handled properly. Here is the example code:
@@ -79,8 +95,12 @@ class QlibRecorder:
Parameters
----------
experiment_id : str
id of the experiment one wants to start.
experiment_name : str
the name of the experiment to be started
recorder_id : str
id of the recorder under the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
@@ -93,7 +113,14 @@ class QlibRecorder:
-------
An experiment instance being started.
"""
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
return self.exp_manager.start_exp(
experiment_id=experiment_id,
experiment_name=experiment_name,
recorder_id=recorder_id,
recorder_name=recorder_name,
uri=uri,
resume=resume,
)
def end_exp(self, recorder_status=Recorder.STATUS_FI):
"""

View File

@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
from pathlib import Path
from .recorder import Recorder, MLflowRecorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class Experiment:
@@ -39,12 +39,14 @@ class Experiment:
output["recorders"] = list(recorders.keys())
return output
def start(self, recorder_name=None, resume=False):
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
"""
Start the experiment and set it to be active. This method will also start a new recorder.
Parameters
----------
recorder_id : str
the id of the recorder to be created.
recorder_name : str
the name of the recorder to be created.
resume : bool
@@ -238,14 +240,14 @@ class MLflowExperiment(Experiment):
def __repr__(self):
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
def start(self, recorder_name=None, resume=False):
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
logger.info(f"Experiment {self.id} starts running ...")
# Get or create recorder
if recorder_name is None:
recorder_name = self._default_rec_name
# resume the recorder
if resume:
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
# create a new recorder
else:
recorder = self.create_recorder(recorder_name)

View File

@@ -4,7 +4,7 @@
import mlflow
from mlflow.exceptions import MlflowException
from mlflow.entities import ViewType
import os
import os, logging
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Text
@@ -14,7 +14,7 @@ from ..config import C
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class ExpManager:
@@ -33,7 +33,10 @@ class ExpManager:
def start_exp(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -45,8 +48,12 @@ class ExpManager:
Parameters
----------
experiment_id : str
id of the active experiment.
experiment_name : str
name of the active experiment.
recorder_id : str
id of the recorder to be started.
recorder_name : str
name of the recorder to be started.
uri : str
@@ -298,7 +305,10 @@ class MLflowExpManager(ExpManager):
def start_exp(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -308,11 +318,11 @@ class MLflowExpManager(ExpManager):
# Create experiment
if experiment_name is None:
experiment_name = self._default_exp_name
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
# Set up active experiment
self.active_experiment = experiment
# Start the experiment
self.active_experiment.start(recorder_name, resume)
self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)
return self.active_experiment

View File

@@ -1,477 +1,177 @@
from copy import deepcopy
from operator import index
import pandas as pd
from qlib.model.ens.ensemble import ens_workflow
from qlib.model.ens.group import RollingGroup
from qlib.utils.serial import Serializable
from typing import Dict, List, Union
from qlib import get_module_logger
from qlib.data.data import D
from qlib.model.trainer import Trainer, TrainerR, task_train
from qlib.workflow import R
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This class is a component of online serving, it can manage a series of models dynamically.
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models.
In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
So this module provide a series methods to control this process.
This module also provide a method to simulate `Online Strategy <#Online Strategy>`_ in the history.
Which means you can verify your strategy or find a better one.
"""
from typing import Dict, List, Union
import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.task.collect import HyperCollector
class OnlineManager(Serializable):
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
# NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models.
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment.
def __init__(self, trainer: Trainer = None, need_log=True):
"""
init OnlineManager.
Args:
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
self.trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.cur_time = None
def prepare_signals(self):
"""
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
Must use `pass` even though there is nothing to do.
"""
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def get_signals(self):
"""
After preparing signals, here is the method to get them.
"""
raise NotImplementedError(f"Please implement the `get_signals` method.")
def prepare_tasks(self, *args, **kwargs):
"""
After the end of a routine, check whether we need to prepare and train some new tasks.
return the new tasks waiting for training.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs):
"""
Use trainer to train a list of tasks and set the trained model to `tag`.
Args:
tasks (list): a list of tasks.
tag (str):
`ONLINE_TAG` for first train or additional train
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
`OFFLINE_TAG` for train but offline those models
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
"""
if check_func is None:
check_func = lambda x: True
if len(tasks) > 0:
if self.trainer is not None:
new_models = self.trainer.train(tasks, *args, **kwargs)
if check_func(new_models):
self.set_online_tag(tag, new_models)
if self.need_log:
self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.")
else:
self.logger.warn("No trainer to train new tasks.")
def update_online_pred(self):
"""
After the end of a routine, update the predictions of online models to latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
def set_online_tag(self, tag, recorder):
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self):
"""
Given a model and return its online tag.
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, recorders=None):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Args:
recorders (List, optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
Returns:
list: new online recorder. [] if there is no update.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def online_models(self):
"""
Return online models.
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def first_train(self):
"""
Train a series of models firstly and set some of them into online models.
"""
raise NotImplementedError(f"Please implement the `first_train` method.")
def get_collector(self):
"""
Return the collector.
Returns:
Collector
"""
raise NotImplementedError(f"Please implement the `get_collector` method.")
def delay_prepare(self, rec_dict, *args, **kwargs):
"""
Prepare all models and signals if there are something waiting for prepare.
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
Args:
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
"""
for time_segment, recs_list in rec_dict.items():
self.trainer.end_train(recs_list, *args, **kwargs)
self.reset_online_tag(recs_list)
self.prepare_signals()
signal_max = self.get_signals().index.get_level_values("datetime").max()
if time_segment[1] is not None and signal_max > time_segment[1]:
raise ValueError(
f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}"
)
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
"""
The typical update process after a routine, such as day by day or month by month.
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions.
Args:
cur_time ([type], optional): [description]. Defaults to None.
delay_prepare (bool, optional): [description]. Defaults to False.
*args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config.
Returns:
[type]: [description]
"""
self.cur_time = cur_time # None for latest date
if not delay_prepare:
self.update_online_pred()
self.prepare_signals()
tasks = self.prepare_tasks(*args, **kwargs)
self.prepare_new_models(tasks, *args, **kwargs)
return self.reset_online_tag()
class OnlineManagerR(OnlineManager):
"""
The implementation of OnlineManager based on (R)ecorder.
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
It also provide a history recording which models are onlined at what time.
"""
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
"""
init OnlineManagerR.
Args:
experiment_name (str): the experiment name.
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(trainer=trainer, need_log=need_log)
self.exp_name = experiment_name
self.signal_rec = None
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[Recorder, List])
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
if self.need_log:
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder):
"""
Given a model and return its online tag.
Args:
recorder (Recorder): a instance of recorder
Returns:
str: the tag
"""
tags = recorder.list_tags()
return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List] = None):
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
Args:
recorders (Union[Recorder, List], optional):
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
Returns:
list: new online recorder. [] if there is no update.
"""
if recorder is None:
recorder = list(
list_recorders(
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
).values()
)
if isinstance(recorder, Recorder):
recorder = [recorder]
if len(recorder) == 0:
if self.need_log:
self.logger.info("No 'next online' model, just use current 'online' models.")
return []
recs = list_recorders(self.exp_name)
self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
return recorder
def get_signals(self):
"""
get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
Returns:
signals
"""
if self.signal_rec is None:
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
self.signal_rec = R.get_recorder()
signals = None
try:
signals = self.signal_rec.load_object("signals")
except OSError:
self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?")
return signals
def online_models(self):
"""
Return online models.
Returns:
list: the list of online models
"""
return list(
list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG).values()
)
def update_online_pred(self):
"""
Update all online model predictions to the latest day in Calendar
"""
online_models = self.online_models()
for rec in online_models:
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
if self.need_log:
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
def prepare_signals(self, over_write=False):
"""
Average the predictions of online models and offer a trading signals every routine.
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
Even if the latest signal already exists, the latest calculation result will be overwritten.
NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.
Args:
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
"""
if self.signal_rec is None:
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
self.signal_rec = R.get_recorder()
pred = []
try:
old_signals = self.signal_rec.load_object("signals")
except OSError:
old_signals = None
for rec in self.online_models():
pred.append(rec.load_object("pred.pkl"))
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
signals = signals.sort_index()
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
if self.need_log:
self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
self.signal_rec.save_objects(**{"signals": signals})
class RollingOnlineManager(OnlineManagerR):
"""An implementation of OnlineManager based on Rolling."""
def __init__(
self,
experiment_name: str,
rolling_gen: RollingGen,
trainer: Trainer = None,
strategy: Union[OnlineStrategy, List[OnlineStrategy]],
begin_time: Union[str, pd.Timestamp] = None,
freq="day",
need_log=True,
):
"""
init RollingOnlineManager.
Init OnlineManager.
One OnlineManager must have at least one OnlineStrategy.
Args:
experiment_name (str): the experiment name.
rolling_gen (RollingGen): a instance of RollingGen
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
collector (Collector, optional): a instance of Collector. Defaults to None.
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
freq (str, optional): data frequency. Defaults to "day".
need_log (bool, optional): print log or not. Defaults to True.
"""
if trainer is None:
trainer = TrainerR(experiment_name)
super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log)
self.ta = TimeAdjuster()
self.rg = rolling_gen
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
if not isinstance(strategy, list):
strategy = [strategy]
self.strategy = strategy
self.freq = freq
if begin_time is None:
begin_time = D.calendar(freq=self.freq).max()
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
self.history = {}
def get_collector(self, rec_key_func=None, rec_filter_func=None):
def first_train(self):
"""
Get the instance of collector to collect results. The returned collector must can distinguish results in different models.
Assumption: the models can be distinguished based on model name and rolling test segments.
If you do not want this assumption, please implement your own method or use another rec_key_func.
Run every strategy first_train method and record the online history.
"""
for strategy in self.strategy:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
online_models = strategy.first_train()
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}):
"""
Run typical update process for every strategy and record the online history.
The typical update process after a routine, such as day by day or month by month.
The process is: Prepare signals -> Prepare tasks -> Prepare online models.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
"""
if cur_time is None:
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
for strategy in self.strategy:
if self.need_log:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if not strategy.trainer.is_delay():
strategy.prepare_signals()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
online_models = strategy.prepare_online_models(tasks, **model_kwargs)
if len(online_models) > 0:
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
if rec_key_func is None:
rec_key_func = rec_key
return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
def get_collector(self) -> HyperCollector:
"""
collecting artifact based on the collector and RollingGroup.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
Returns:
dict: the artifact dict after rolling ensemble
HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable).
"""
artifact = ens_workflow(
self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func), RollingGroup()
)
return artifact
collector_dict = {}
for strategy in self.strategy:
collector_dict[strategy.name_id] = strategy.get_collector()
return HyperCollector(collector_dict, process_list=SingleKeyEnsemble())
def first_train(self, task_configs: list):
def get_online_history(self, strategy_name_id: str) -> list:
"""
Use rolling_gen to generate different tasks based on task_configs and trained them.
Get the online history based on strategy_name_id.
Args:
task_configs (list or dict): a list of task configs or a task config
strategy_name_id (str): the name_id of strategy
Returns:
Collector: a instance of a Collector.
list: a list like [(begin_time, [online_models])]
"""
tasks = task_generator(
tasks=task_configs,
generators=self.rg, # generate different date segment
)
self.prepare_new_models(tasks, tag=self.ONLINE_TAG)
history_dict = self.history[strategy_name_id]
history = []
for time in sorted(history_dict):
models = history_dict[time]
history.append((time, models))
return history
def delay_prepare(self, delay_kwargs={}):
"""
Prepare all models and signals if there are something waiting for prepare.
Args:
delay_kwargs: the params for `delay_prepare`
"""
for strategy in self.strategy:
strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs)
def get_signals(self) -> pd.DataFrame:
"""
Average all strategy signals as the online signals.
Assumption: the signals from every strategy is pd.DataFrame. Override this function to change.
Returns:
pd.DataFrame: signals
"""
signals_dict = {}
for strategy in self.strategy:
signals_dict[strategy.name_id] = strategy.get_signals()
return AverageEnsemble()(signals_dict)
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector:
"""
Starting from current time, this method will simulate every routine in OnlineManager until end time.
Considering the parallel training, the models and signals can be perpared after all routine simulating.
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
Returns:
HyperCollector: the OnlineManager's collector
"""
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
self.first_train()
for cur_time in cal:
self.logger.info(f"Simulating at {str(cur_time)}......")
self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
self.delay_prepare(delay_kwargs=delay_kwargs)
self.logger.info(f"Finished preparing signals")
return self.get_collector()
def prepare_tasks(self):
def reset(self):
"""
Prepare new tasks based on new date.
This method will reset all strategy!
Returns:
list: a list of new tasks.
**Be careful to use it.**
"""
latest_records, max_test = self.list_latest_recorders(
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
)
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time
if self.need_log:
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rid, rec in latest_records.items():
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
def list_latest_recorders(self, rec_filter_func=None):
"""find latest recorders based on test segments.
Args:
rec_filter_func (Callable, optional): recorder filter. Defaults to None.
Returns:
dict, tuple: the latest recorders and the latest date of them
"""
recs_flt = list_recorders(self.exp_name, rec_filter_func)
if len(recs_flt) == 0:
return recs_flt, None
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
latest_rec = {}
for rid, rec in recs_flt.items():
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec[rid] = rec
return latest_rec, max_test
self.cur_time = self.begin_time
self.history = {}
for strategy in self.strategy:
strategy.reset()

View File

@@ -1,72 +0,0 @@
from qlib.data import D
from qlib import get_module_logger
from qlib.workflow.online.manager import OnlineManager
class OnlineSimulator:
"""
To simulate online serving in the past, like a "online serving backtest".
"""
def __init__(
self,
start_time,
end_time,
online_manager: OnlineManager,
frequency="day",
):
"""
init OnlineSimulator.
Args:
start_time (str or pd.Timestamp): the start time of simulating.
end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest.
onlinemanager (OnlineManager): the instance of OnlineManager
frequency (str, optional): the data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
self.start_time = self.cal[0]
self.end_time = self.cal[-1]
self.olm = online_manager
if len(self.cal) == 0:
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
def simulate(self, *args, **kwargs):
"""
Starting from start time, this method will simulate every routine in OnlineManager.
NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
Returns:
Collector: the OnlineManager's collector
"""
self.rec_dict = {}
tmp_begin = self.start_time
tmp_end = None
prev_recorders = self.olm.online_models()
for cur_time in self.cal:
self.logger.info(f"Simulating at {str(cur_time)}......")
recorders = self.olm.routine(cur_time, True, *args, **kwargs)
if len(recorders) == 0:
tmp_end = cur_time
else:
self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
tmp_begin = cur_time
prev_recorders = recorders
self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
# finished perparing models (and pred) and signals
self.olm.delay_prepare(self.rec_dict)
self.logger.info(f"Finished preparing signals")
return self.olm.get_collector()
def online_models(self):
"""
Return a online models dict likes {(begin_time, end_time):[online models]}.
Returns:
dict
"""
if hasattr(self, "rec_dict"):
return self.rec_dict
self.logger.warn(f"Please call `simulate` firstly when calling `online_models`")
return {}

View File

@@ -0,0 +1,339 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineStrategy is a set of strategy for online serving.
"""
from copy import deepcopy
from typing import List, Tuple, Union
import pandas as pd
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import Trainer, TrainerR
from qlib.workflow import R
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
class OnlineStrategy:
"""
OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared.
"""
def __init__(self, name_id: str, trainer: Trainer = None, need_log=True):
"""
Init OnlineStrategy.
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
Args:
name_id (str): a unique name or id
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
"""
self.name_id = name_id
self.trainer = trainer
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
self.tool = OnlineTool()
def prepare_signals(self, delay: bool = False):
"""
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
NOTE: Given a set prediction, all signals before these prediction end time will be prepared well.
Args:
delay: bool
If this method was called by `delay_prepare`
"""
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
def prepare_tasks(self, *args, **kwargs):
"""
After the end of a routine, check whether we need to prepare and train some new tasks.
Return the new tasks waiting for training.
You can find last online models by OnlineTool.online_models.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_online_models(self, tasks, check_func=None, **kwargs):
"""
Use trainer to train a list of tasks and set the trained model to `online`.
NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them.
Args:
tasks (list): a list of tasks.
check_func: the method to judge if a model can be online.
The parameter is the model record and return True for online.
None for online every models.
**kwargs: will be passed to end_train which means will be passed to customized train method.
"""
if check_func is None:
check_func = lambda x: True
online_models = []
if len(tasks) > 0:
new_models = self.trainer.train(tasks, **kwargs)
for model in new_models:
if check_func(model):
online_models.append(model)
self.tool.reset_online_tag(online_models)
return online_models
def first_train(self):
"""
Train a series of models firstly and set some of them as online models.
"""
raise NotImplementedError(f"Please implement the `first_train` method.")
def get_collector(self) -> Collector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving.
For example:
1) collect predictions in Recorder
2) collect signals in .txt file
Returns:
Collector
"""
raise NotImplementedError(f"Please implement the `get_collector` method.")
def delay_prepare(self, history: list, **kwargs):
"""
Prepare all models and signals if there are something waiting for prepare.
Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way.
Args:
history (list): an online models list likes [begin_time:[online models]].
**kwargs: will be passed to end_train which means will be passed to customized train method.
"""
for begin_time, recs_list in history:
self.trainer.end_train(recs_list, **kwargs)
self.tool.reset_online_tag(recs_list)
self.prepare_signals(delay=True)
def get_signals(self):
"""
Get prepared signals.
"""
raise NotImplementedError(f"Please implement the `get_signals` method.")
def reset(self):
"""
Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation.
"""
pass
class RollingAverageStrategy(OnlineStrategy):
"""
This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models
"""
def __init__(
self,
name_id: str,
task_template: Union[dict, List[dict]],
rolling_gen: RollingGen,
trainer: Trainer = None,
need_log=True,
signal_exp_name="OnlineManagerSignals",
):
"""
Init RollingAverageStrategy.
Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.
Args:
name_id (str): a unique name or id. Will be also the name of Experiment.
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
rolling_gen (RollingGen): an instance of RollingGen
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
need_log (bool, optional): print log or not. Defaults to True.
signal_exp_path (str): a specific experiment to save signals of different experiment.
"""
super().__init__(name_id=name_id, trainer=trainer, need_log=need_log)
self.exp_name = self.name_id
if not isinstance(task_template, list):
task_template = [task_template]
self.task_template = task_template
self.signal_exp_name = signal_exp_name
self.rg = rolling_gen
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
self.signal_rec = R.get_recorder() # the recorder to record signals
self.signal_rec.save_objects(**{"signals": None})
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must can distinguish results in different models.
Assumption: the models can be distinguished based on model name and rolling test segments.
If you do not want this assumption, please implement your own method or use another rec_key_func.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
"""
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
if rec_key_func is None:
rec_key_func = rec_key
artifacts_collector = RecorderCollector(
experiment=self.exp_name,
process_list=process_list,
rec_key_func=rec_key_func,
rec_filter_func=rec_filter_func,
artifacts_key=artifacts_key,
)
return artifacts_collector
def first_train(self) -> List[Recorder]:
"""
Use rolling_gen to generate different tasks based on task_template and trained them.
Returns:
List[Recorder]: a list of Recorder.
"""
tasks = task_generator(
tasks=self.task_template,
generators=self.rg, # generate different date segment
)
return self.prepare_online_models(tasks)
def prepare_tasks(self, cur_time) -> List[dict]:
"""
Prepare new tasks based on cur_time (None for latest).
You can find last online models by OnlineToolR.online_models.
Returns:
List[dict]: a list of new tasks.
"""
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
if self.need_log:
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame:
"""
Average the predictions of online models and offer a trading signals every routine.
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
Even if the latest signal already exists, the latest calculation result will be overwritten.
.. note::
Given a prediction of a certain time, all signals before this time will be prepared well.
Args:
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
Returns:
pd.DataFrame: the signals.
"""
if not delay:
self.tool.update_online_pred()
# Get a collector to average online models predictions
online_collector = self.get_collector(
process_list=[AverageEnsemble()],
rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False,
artifacts_key="pred",
)
online_results = online_collector()
signals = online_results["pred"]
old_signals = self.get_signals()
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
if self.need_log:
self.logger.info(
f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}."
)
self.signal_rec.save_objects(**{"signals": signals})
return signals
def get_signals(self) -> object:
"""
Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
Returns:
object: signals
"""
signals = self.signal_rec.load_object("signals")
return signals
def _list_latest(self, rec_list: List[Recorder]):
"""
List latest recorder form rec_list
Args:
rec_list (List[Recorder]): a list of Recorder
Returns:
List[Recorder], pd.Timestamp: the latest recorders and its test end time
"""
if len(rec_list) == 0:
return rec_list, None
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
latest_rec = []
for rec in rec_list:
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec.append(rec)
return latest_rec, max_test
def reset(self):
"""
NOTE: This method will delete all recorder in Experiment and reset the Trainer!
"""
self.trainer.reset()
# delete models
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# delete signals
for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False):
exp.delete_recorder(rid)

View File

@@ -1,18 +1,20 @@
from typing import Union, List
from qlib.data.dataset import DatasetH
from qlib.workflow import R
from qlib.data import D
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Updater is a module to update artifacts such as predictions, when the stock data is updating.
"""
from abc import ABCMeta, abstractmethod
import pandas as pd
from qlib import get_module_logger
from qlib.workflow import R
from qlib.model import Model
from qlib.model.trainer import task_train
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data import D
from qlib.data.dataset import DatasetH
from abc import ABCMeta, abstractmethod
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
from qlib.workflow.recorder import Recorder
class RMDLoader:
@@ -25,19 +27,22 @@ class RMDLoader:
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
"""
load, config and setup dataset.
Load, config and setup dataset.
This dataset is for inference
This dataset is for inference.
Args:
start_time :
the start_time of underlying data
end_time :
the end_time of underlying data
segments : dict
the segments config for dataset
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
Returns:
DatasetH: the instance of DatasetH
Parameters
----------
start_time :
the start_time of underlying data
end_time :
the end_time of underlying data
segments : dict
the segments config for dataset
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
"""
if segments is None:
segments = {"test": (start_time, end_time)}
@@ -52,7 +57,7 @@ class RMDLoader:
class RecordUpdater(metaclass=ABCMeta):
"""
Updata a specific recorders
Update a specific recorders
"""
def __init__(self, record: Recorder, need_log=True, *args, **kwargs):
@@ -75,17 +80,22 @@ class PredUpdater(RecordUpdater):
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
"""
Parameters
----------
record : Recorder
to_date :
update to prediction to the `to_date`
hist_ref : int
Sometimes, the dataset will have historical depends.
Leave the problem to user to set the length of historical dependancy
NOTE: the start_time is not included in the hist_ref
# TODO: automate this step in the future.
Init PredUpdater.
Args:
record : Recorder
to_date :
update to prediction to the `to_date`
hist_ref : int
Sometimes, the dataset will have historical depends.
Leave the problem to user to set the length of historical dependency
.. note::
the start_time is not included in the hist_ref
"""
# TODO: automate this hist_ref in the future.
super().__init__(record=record, need_log=need_log)
self.to_date = to_date
@@ -101,9 +111,12 @@ class PredUpdater(RecordUpdater):
def prepare_data(self) -> DatasetH:
"""
# Load dataset
Load dataset
Seperating this function will make it easier to reuse the dataset
Returns:
DatasetH: the instance of DatasetH
"""
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
@@ -113,9 +126,12 @@ class PredUpdater(RecordUpdater):
def update(self, dataset: DatasetH = None):
"""
update the precition in a recorder
Update the precition in a recorder
Args:
DatasetH: the instance of DatasetH. None for reprepare.
"""
# FIXME: the problme below is not solved
# FIXME: the problem below is not solved
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797

View File

@@ -0,0 +1,170 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineTool is a module to set and unset a series of `online` models.
The `online` models are some decisive models in some time point, which can be changed with the change of time.
This allows us to use efficient submodels as the market style changing.
"""
from typing import List, Union
from qlib.log import get_module_logger
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
class OnlineTool:
"""
OnlineTool.
"""
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self, need_log=True):
"""
Init OnlineTool.
Args:
need_log (bool, optional): print log or not. Defaults to True.
"""
self.logger = get_module_logger(self.__class__.__name__)
self.need_log = need_log
def set_online_tag(self, tag, recorder: Union[list, object]):
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[list,object]): the model's recorder
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, recorder: object) -> str:
"""
Given a model recorder and return its online tag.
Args:
recorder (Object): the model's recorder
Returns:
str: the online tag
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, recorder: Union[list, object]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[list,object]):
the recorder you want to reset to 'online'.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def update_online_pred(self, to_date=None):
"""
Update the predictions of `online` models to a date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
class OnlineToolR(OnlineTool):
"""
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str, need_log=True):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
need_log (bool, optional): print log or not. Defaults to True.
"""
super().__init__(need_log=need_log)
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Set `tag` to the model's recorder to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
if self.need_log:
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder) -> str:
"""
Given a model recorder and return its online tag.
Args:
recorder (Recorder): an instance of recorder
Returns:
str: the online tag
"""
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
"""
Update the predictions of online models to a date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest time in Calendar.
"""
online_models = self.online_models()
for rec in online_models:
PredUpdater(rec, to_date=to_date, need_log=self.need_log).update()
if self.need_log:
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import re, logging
import pandas as pd
from pathlib import Path
from pprint import pprint
@@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
from ..contrib.strategy.strategy import BaseStrategy
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class RecordTemp:
@@ -166,6 +166,60 @@ class SignalRecord(RecordTemp):
return super().load(name)
class HFSignalRecord(SignalRecord):
"""
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
"""
artifact_path = "hg_sig_analysis"
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder)
def generate(self):
pred = self.load("pred.pkl")
raw_label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
"Long precision": long_pre.mean(),
"Short precision": short_pre.mean(),
}
objects = {"ic.pkl": ic, "ric.pkl": ric}
objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre})
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0])
metrics.update(
{
"Long-Short Average Return": long_short_r.mean(),
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
}
)
objects.update(
{
"long_short_r.pkl": long_short_r,
"long_avg_r.pkl": long_avg_r,
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
class SigAnaRecord(SignalRecord):
"""
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.

View File

@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from ..utils.objm import FileManager
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class Recorder:

View File

@@ -1,8 +1,12 @@
from abc import abstractmethod
from typing import Callable, Union
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on.
"""
from qlib.model.ens.ensemble import SingleKeyEnsemble
from qlib.workflow import R
from qlib.workflow.task.utils import list_recorders
from qlib.utils.serial import Serializable
import dill as pickle
@@ -18,7 +22,7 @@ class Collector:
process_list = [process_list]
self.process_list = process_list
def collect(self):
def collect(self) -> dict:
"""Collect the results and return a dict like {key: things}
Returns:
@@ -35,7 +39,7 @@ class Collector:
raise NotImplementedError(f"Please implement the `collect` method.")
@staticmethod
def process_collect(collected_dict, process_list=[], *args, **kwargs):
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
"""do a series of processing to the dict returned by collect and return a dict like {key: things}
For example: you can group and ensemble.
@@ -60,7 +64,7 @@ class Collector:
result[artifact] = value
return result
def __call__(self, *args, **kwargs):
def __call__(self, *args, **kwargs) -> dict:
"""
do the workflow including collect and process_collect
@@ -78,7 +82,7 @@ class Collector:
filepath (str): the path of file
Returns:
bool: if successed
bool: if succeeded
"""
try:
with open(filepath, "wb") as f:
@@ -109,6 +113,29 @@ class Collector:
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
class HyperCollector(Collector):
"""
A collector to collect the results of other Collectors
"""
def __init__(self, collector_dict, process_list=[]):
"""
Args:
collector_dict (dict): the dict like {collector_key, Collector}
process_list (list or Callable): the list of processors or the instance of processor to process dict.
NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict.
This can make result more readable. If you want to maintain as it should be, just give a empty process list.
"""
super().__init__(process_list=process_list)
self.collector_dict = collector_dict
def collect(self) -> dict:
collect_dict = {}
for key, collector in self.collector_dict.items():
collect_dict[key] = collector()
return collect_dict
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
@@ -131,10 +158,10 @@ class RecorderCollector(Collector):
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
self.experiment = experiment
self.process_list = process_list
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
@@ -144,7 +171,7 @@ class RecorderCollector(Collector):
self.artifacts_key = artifacts_key
self._rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None):
def collect(self, artifacts_key=None, rec_filter_func=None) -> dict:
"""Collect different artifacts based on recorder after filtering.
Args:
@@ -180,3 +207,12 @@ class RecorderCollector(Collector):
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict
def get_exp_name(self) -> str:
"""
Get experiment name
Returns:
str: experiment name
"""
return self.experiment.name

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
this is a task generator
Task generator can generate many tasks based on TaskGen and some task templates.
"""
import abc
import copy
@@ -113,7 +113,7 @@ class RollingGen(TaskGen):
self.test_key = "test"
self.train_key = "train"
def generate(self, task: dict):
def generate(self, task: dict) -> typing.List[dict]:
"""
Converting the task into a rolling task.
@@ -158,6 +158,10 @@ class RollingGen(TaskGen):
},
]
}
Returns
----------
typing.List[dict]: a list of tasks
"""
res = []
@@ -196,16 +200,18 @@ class RollingGen(TaskGen):
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
# if end_time < the end of test_segments, then change end_time to allow load more data
if (
self.modify_end_time
and self.ta.cal_interval(
try:
interval = self.ta.cal_interval(
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
t["dataset"]["kwargs"]["segments"][self.test_key][1],
)
< 0
):
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
# if end_time < the end of test_segments, then change end_time to allow load more data
if self.modify_end_time and interval < 0:
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
except KeyError:
# Maybe the user dataset has no handler or end_time
pass
prev_seg = segments
res.append(t)
return res

View File

@@ -1,31 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
A task consists of 3 parts
TaskManager can fetch unused tasks automatically and manager the lifecycle of a set of tasks with error handling.
These features can run tasks concurrently and ensure every task will be used only once.
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
A task in TaskManager consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result information : A user can get the task with the task description and task result.
"""
from bson.binary import Binary
import pickle
from pymongo.errors import InvalidDocument
from bson.objectid import ObjectId
from contextlib import contextmanager
import qlib
from tqdm.cli import tqdm
import time
import concurrent
import pymongo
from qlib.config import C
from .utils import get_mongodb
from qlib import get_module_logger, auto_init
import pickle
import time
from contextlib import contextmanager
from typing import Callable, List
import fire
import pymongo
from bson.binary import Binary
from bson.objectid import ObjectId
from pymongo.errors import InvalidDocument
from qlib import auto_init, get_module_logger
from tqdm.cli import tqdm
from .utils import get_mongodb
class TaskManager:
"""TaskManager
here is what will a task looks like when it created by TaskManager
"""
TaskManager
Here is what will a task looks like when it created by TaskManager
.. code-block:: python
@@ -42,6 +50,16 @@ class TaskManager:
.. note::
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
Here are four status which are:
STATUS_WAITING: waiting for train
STATUS_RUNNING: training
STATUS_PART_DONE: finished some step and waiting for next step
STATUS_DONE: all work done
"""
STATUS_WAITING = "waiting"
@@ -53,7 +71,7 @@ class TaskManager:
def __init__(self, task_pool: str = None):
"""
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
Parameters
----------
@@ -65,7 +83,7 @@ class TaskManager:
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self):
def list(self) -> list:
"""
list the all collection(task_pool) of the db
@@ -92,7 +110,9 @@ class TaskManager:
return {k: str(v) for k, v in flt.items()}
def replace_task(self, task, new_task):
# assume that the data out of interface was decoded and the data in interface was encoded
"""
Use a new task to replace a old one
"""
new_task = self._encode_task(new_task)
query = {"_id": ObjectId(task["_id"])}
try:
@@ -121,7 +141,7 @@ class TaskManager:
Returns
-------
pymongo.results.InsertOneResult
"""
task = self._encode_task(
{
@@ -133,9 +153,9 @@ class TaskManager:
insert_result = self.insert_task(task)
return insert_result
def create_task(self, task_def_l, dry_run=False, print_nt=False):
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
"""
if the tasks in task_def_l is new, then insert new tasks into the task_pool
If the tasks in task_def_l is new, then insert new tasks into the task_pool
Parameters
----------
@@ -145,6 +165,7 @@ class TaskManager:
if insert those new tasks to task pool
print_nt: bool
if print new task
Returns
-------
list
@@ -165,7 +186,7 @@ class TaskManager:
print(t)
if dry_run:
return
return []
_id_list = []
for t in new_tasks:
@@ -174,7 +195,17 @@ class TaskManager:
return _id_list
def fetch_task(self, query={}, status=STATUS_WAITING):
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
"""
Use query to fetch tasks
Args:
query (dict, optional): query dict. Defaults to {}.
status (str, optional): [description]. Defaults to STATUS_WAITING.
Returns:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
@@ -191,7 +222,7 @@ class TaskManager:
@contextmanager
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
"""
fetch task from task_pool using query with contextmanager
Fetch task from task_pool using query with contextmanager
Parameters
----------
@@ -200,7 +231,7 @@ class TaskManager:
Returns
-------
dict: a task(document in collection) after decoding
"""
task = self.fetch_task(query=query, status=status)
try:
@@ -231,7 +262,7 @@ class TaskManager:
Returns
-------
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
@@ -240,16 +271,40 @@ class TaskManager:
yield self._decode_task(t)
def re_query(self, _id):
"""
Use _id to query task.
Args:
_id (str): _id of a document
Returns:
dict: a task(document in collection) after decoding
"""
t = self.task_pool.find_one({"_id": ObjectId(_id)})
return self._decode_task(t)
def commit_task_res(self, task, res, status=None):
def commit_task_res(self, task, res, status=STATUS_DONE):
"""
Commit the result to task['res'].
Args:
task ([type]): [description]
res (object): the result you want to save
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
"""
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
def return_task(self, task, status=None):
def return_task(self, task, status=STATUS_WAITING):
"""
Return a task to status. Alway using in error handling.
Args:
task ([type]): [description]
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
"""
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {"$set": {"status": status}}
@@ -257,7 +312,7 @@ class TaskManager:
def remove(self, query={}):
"""
remove the task using query
Remove the task using query
Parameters
----------
@@ -295,7 +350,7 @@ class TaskManager:
def prioritize(self, task, priority: int):
"""
set priority for task
Set priority for task
Parameters
----------
@@ -331,29 +386,41 @@ class TaskManager:
def run_task(
task_func,
task_pool,
force_release=False,
before_status=TaskManager.STATUS_WAITING,
after_status=TaskManager.STATUS_DONE,
*args,
task_func: Callable,
task_pool: str,
force_release: bool = False,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
):
"""
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
After running this method, here are 4 situations (before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
Parameters
----------
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
the function to run the task
task_func : Callable
def (task_def, **kwargs) -> <res which will be committed>
the function to run the task
task_pool : str
the name of the task pool (Collection in MongoDB)
force_release :
force_release : bool
will the program force to release the resource
args :
args
kwargs :
kwargs
before_status : str:
the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status : str:
the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs
the params for `task_func`
"""
tm = TaskManager(task_pool)
@@ -364,19 +431,19 @@ def run_task(
if task is None:
break
get_module_logger("run_task").info(task["def"])
# when fetching `WAITING` task, use task_def to train
# when fetching `WAITING` task, use task["def"] to train
if before_status == TaskManager.STATUS_WAITING:
param = task["def"]
# when fetching `PART_DONE` task, use task_res to train for the result has been saved
# when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
elif before_status == TaskManager.STATUS_PART_DONE:
param = task["res"]
else:
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, param, *args, **kwargs).result()
res = executor.submit(task_func, param, **kwargs).result()
else:
res = task_func(param, *args, **kwargs)
res = task_func(param, **kwargs)
tm.commit_task_res(task, res, status=after_status)
ever_run = True

View File

@@ -1,5 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Some tools for task management.
"""
import bisect
import pandas as pd
from qlib.data import D
@@ -7,13 +12,14 @@ from qlib.workflow import R
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
from pymongo.database import Database
from typing import Union
def get_mongodb():
"""
def get_mongodb() -> Database:
get database in MongoDB, which means you need to declare the address and the name of database.
"""
Get database in MongoDB, which means you need to declare the address and the name of database.
for example:
Using qlib.init():
@@ -31,6 +37,8 @@ def get_mongodb():
"task_db_name" : "rolling_db"
}
Returns:
Database: the Database instance
"""
try:
cfg = C["mongo"]
@@ -43,7 +51,8 @@ def get_mongodb():
def list_recorders(experiment, rec_filter_func=None):
"""list all recorders which can pass the filter in a experiment.
"""
List all recorders which can pass the filter in a experiment.
Args:
experiment (str or Experiment): the name of a Experiment or a instance
@@ -65,7 +74,7 @@ def list_recorders(experiment, rec_filter_func=None):
class TimeAdjuster:
"""
find appropriate date and adjust date.
Find appropriate date and adjust date.
"""
def __init__(self, future=True, end_time=None):
@@ -88,15 +97,15 @@ class TimeAdjuster:
return None
return self.cals[idx]
def max(self):
def max(self) -> pd.Timestamp:
"""
Return the max calendar datetime
"""
return max(self.cals)
def align_idx(self, time_point, tp_type="start"):
def align_idx(self, time_point, tp_type="start") -> int:
"""
align the index of time_point in the calendar
Align the index of time_point in the calendar
Parameters
----------
@@ -116,9 +125,9 @@ class TimeAdjuster:
raise NotImplementedError(f"This type of input is not supported")
return idx
def cal_interval(self, time_point_A, time_point_B):
def cal_interval(self, time_point_A, time_point_B) -> int:
"""
calculate the trading day interval
Calculate the trading day interval (time_point_A - time_point_B)
Args:
time_point_A : time_point_A
@@ -129,20 +138,22 @@ class TimeAdjuster:
"""
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
def align_time(self, time_point, tp_type="start"):
def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
"""
Align time_point to trade date of calendar
Parameters
----------
time_point
Time point
tp_type : str
time point type (`"start"`, `"end"`)
Args:
time_point
Time point
tp_type : str
time point type (`"start"`, `"end"`)
Returns:
pd.Timestamp
"""
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment: Union[dict, tuple]):
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
"""
align the given date to trade date
@@ -162,7 +173,7 @@ class TimeAdjuster:
Returns
-------
the start and end trade date (pd.Timestamp) between the given start and end date.
Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
"""
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
@@ -171,7 +182,7 @@ class TimeAdjuster:
else:
raise NotImplementedError(f"This type of input is not supported")
def truncate(self, segment: tuple, test_start, days: int):
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
"""
truncate the segment based on the test_start date
@@ -183,6 +194,10 @@ class TimeAdjuster:
days : int
The trading days to be truncated
the data in this segment may need 'days' data
Returns
---------
tuple: new segment
"""
test_idx = self.align_idx(test_start)
if isinstance(segment, tuple):
@@ -198,7 +213,7 @@ class TimeAdjuster:
SHIFT_SD = "sliding"
SHIFT_EX = "expanding"
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
shift the datatime of segment
@@ -211,6 +226,10 @@ class TimeAdjuster:
rtype : str
rolling type ("sliding" or "expanding")
Returns
--------
tuple: new segment
Raises
------
KeyError:

View File

@@ -1,12 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys, traceback, signal, atexit
import sys, traceback, signal, atexit, logging
from . import R
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
# function to handle the experiment when unusual program ending occurs

View File

@@ -0,0 +1,24 @@
# Get future trading days
> `D.calendar(future=True)` will be used
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
```
## Parameters
- qlib_dir: qlib data directory
- freq: value from [`day`, `1min`], default `day`

View File

@@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from typing import List
from pathlib import Path
import fire
import numpy as np
import pandas as pd
from loguru import logger
# get data from baostock
import baostock as bs
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt")
if not calendar_path.exists():
return pd.DataFrame()
return pd.read_csv(calendar_path, header=None)
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt"))
np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
logger.info(f"write future calendars success: {calendar_path}")
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
print(freq)
if freq == "day":
return date_list
elif freq == "1min":
date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list))
else:
raise ValueError(f"Unsupported freq: {freq}")
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
"""get future calendar
Parameters
----------
qlib_dir: str or Path
qlib data directory
freq: str
value from ["day", "1min"], by default day
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
if not qlib_dir.exists():
raise FileNotFoundError(str(qlib_dir))
lg = bs.login()
if lg.error_code != "0":
logger.error(f"login error: {lg.error_msg}")
return
# read daily calendar
daily_calendar = read_calendar_from_qlib(qlib_dir)
end_year = pd.Timestamp.now().year
if daily_calendar.empty:
start_year = pd.Timestamp.now().year
else:
start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
data_list = []
while (rs.error_code == "0") & rs.next():
_row_data = rs.get_row_data()
if int(_row_data[1]) == 1:
data_list.append(_row_data[0])
data_list = sorted(data_list)
date_list = generate_qlib_calendar(data_list, freq=freq)
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
bs.logout()
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
if __name__ == "__main__":
fire.Fire(future_calendar_collector)

View File

@@ -0,0 +1,5 @@
baostock
fire
numpy
pandas
loguru

View File

@@ -10,7 +10,9 @@ import random
import requests
import functools
from pathlib import Path
from typing import Iterable, Tuple
import numpy as np
import pandas as pd
from lxml import etree
from loguru import logger
@@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
return res
def generate_minutes_calendar_from_daily(
calendars: Iterable,
freq: str = "1min",
am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
) -> pd.Index:
"""generate minutes calendar
Parameters
----------
calendars: Iterable
daily calendar
freq: str
by default 1min
am_range: Tuple[str, str]
AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
pm_range: Tuple[str, str]
PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")
"""
daily_format: str = "%Y-%m-%d"
res = []
for _day in calendars:
for _range in [am_range, pm_range]:
res.append(
pd.date_range(
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}",
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}",
freq=freq,
)
)
return pd.Index(sorted(set(np.hstack(res))))
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
from data_collector.utils import (
get_calendar_list,
get_hs_stock_symbols,
get_us_stock_symbols,
generate_minutes_calendar_from_daily,
)
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
return calendar_list_1d
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
res = []
daily_format = self.DAILY_FORMAT
am_range = self.AM_RANGE
pm_range = self.PM_RANGE
for _day in calendars:
for _range in [am_range, pm_range]:
res.append(
pd.date_range(
f"{_day.strftime(daily_format)} {_range[0]}",
f"{_day.strftime(daily_format)} {_range[1]}",
freq="1min",
)
)
return pd.Index(sorted(set(np.hstack(res))))
return generate_minutes_calendar_from_daily(
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
# TODO: using daily data factor

View File

@@ -35,7 +35,7 @@ REQUIRED = [
"scipy>=1.0.0",
"requests>=2.18.0",
"sacred>=0.7.4",
"python-socketio==3.1.2",
"python-socketio",
"redis>=3.0.1",
"python-redis-lock>=3.3.1",
"schedule>=0.6.0",