mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
Merge branch 'online_srv' into online_srv_blin
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
.. _task_managment:
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
@@ -10,15 +10,17 @@ Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Running`_ and `Task Collecting`_.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling.py>`_.
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record` or anything added by users.
|
||||
The specific task template(/definition/config) can be viewed in
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
@@ -27,15 +29,16 @@ Here is the base class of ``TaskGen``:
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provider a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment.
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information in `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the URL and database name of ``task`` storing like this.
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -45,13 +48,12 @@ Users need to provide the URL and database name of ``task`` storing like this.
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
The CRUD methods of ``task`` can be found in TaskManager.
|
||||
More methods can be seen in the `Github <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/manage.py>`_.
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
Task Running
|
||||
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
@@ -60,14 +62,24 @@ It will run the whole workflow defined by ``task``, which includes *Model*, *Dat
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
``Trainer`` will train a list of tasks and return a list of model recorder.
|
||||
``Qlib`` offer two kind of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
More information is in `here <../reference/api.html#Trainer>`_.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To see the results of ``task`` after running or to update something, ``Qlib`` provides a ``TaskCollector`` to collect the tasks by filter condition (optional).
|
||||
Here are some methods in this class.
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
.. autoclass:: qlib.workflow.task.collect.TaskCollector
|
||||
:members:
|
||||
`Collector <../reference/api.html#Collector>`_ can collect object from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
|
||||
|
||||
``Qlib`` provides a concrete `example <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling_with_updating.py>`_, including a whole process of `Task Generating`_ (using `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_), `Task Storing`_, `Task Running`_ and `Task Collecting`_.
|
||||
Besides, the `example <https://github.com/microsoft/qlib/tree/main/examples/taskmanager/task_manager_rolling_with_updating.py>`_ uses a ``ModelUpdater`` inherited from ``TaskCollector``, which can update the inferences and retrain the model if it is out of date.
|
||||
Actually, the model updating can be viewed as a subset of ``Online Serving``.
|
||||
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is ``Collector``'s second step correspond to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_
|
||||
@@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
|
||||
|
||||
41
docs/component/online.rst
Normal file
41
docs/component/online.rst
Normal file
@@ -0,0 +1,41 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of module for online models using latest data,
|
||||
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` need to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ maybe based on `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
Online Manager
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -42,6 +42,7 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -154,36 +154,71 @@ Record Template
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
RollingGen
|
||||
TaskGen
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.task.gen.RollingGen
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
TaskCollector
|
||||
Trainer
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.task.collect.TaskCollector
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
ModelUpdater
|
||||
Collector
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.task.update.ModelUpdater
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
TimeAdjuster
|
||||
Group
|
||||
--------------------
|
||||
.. autoclass:: qlib.workflow.task.utils.TimeAdjuster
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -1,24 +1,23 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM work based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be showed in task_collecting.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import time
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import TrainerR, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow
|
||||
import pandas as pd
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
"""
|
||||
This example shows how a Trainer work based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be showed in task_collecting.
|
||||
"""
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -139,11 +138,13 @@ class RollingTaskExample:
|
||||
return True
|
||||
return False
|
||||
|
||||
artifact = ens_workflow(
|
||||
RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter),
|
||||
RollingGroup(),
|
||||
collector = RecorderCollector(
|
||||
experiment=self.experiment_name,
|
||||
process_list=RollingGroup(),
|
||||
rec_key_func=rec_key,
|
||||
rec_filter_func=my_filter,
|
||||
)
|
||||
print(artifact)
|
||||
print(collector())
|
||||
|
||||
def main(self):
|
||||
self.reset()
|
||||
|
||||
@@ -1,21 +1,18 @@
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.online.simulator import OnlineSimulator
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This examples is about the OnlineManager and OnlineSimulator based on rolling tasks.
|
||||
The OnlineManager will focus on the updating of your online models.
|
||||
The OnlineSimulator will focus on the simulating real updating routine of your online models.
|
||||
This examples is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerRM
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2018-01-01",
|
||||
@@ -86,10 +83,10 @@ class OnlineSimulationExample:
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
end_time="2018-10-31",
|
||||
tasks=[task_xgboost_config], # , task_lgb_config]
|
||||
tasks=[task_xgboost_config, task_lgb_config],
|
||||
):
|
||||
"""
|
||||
init OnlineManagerExample.
|
||||
Init OnlineManagerExample.
|
||||
|
||||
Args:
|
||||
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
|
||||
@@ -105,6 +102,8 @@ class OnlineSimulationExample:
|
||||
"""
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
mongo_conf = {
|
||||
"task_url": task_url,
|
||||
"task_db_name": task_db_name,
|
||||
@@ -115,62 +114,30 @@ class OnlineSimulationExample:
|
||||
) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)
|
||||
self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=self.rolling_gen,
|
||||
trainer=self.trainer,
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingAverageStrategy(
|
||||
exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False
|
||||
),
|
||||
begin_time=self.start_time,
|
||||
need_log=False,
|
||||
) # The OnlineManager based on Rolling
|
||||
self.onlinesimulator = OnlineSimulator(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
online_manager=self.rolling_online_manager,
|
||||
)
|
||||
self.tasks = tasks
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
self.task_manager.remove()
|
||||
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
# Run this firstly to see the workflow in OnlineManager
|
||||
def first_train(self):
|
||||
print("========== first train ==========")
|
||||
self.reset()
|
||||
self.rolling_online_manager.first_train(self.tasks)
|
||||
|
||||
# Run this secondly to see the simulating in OnlineSimulator
|
||||
def simulate(self):
|
||||
print("========== simulate ==========")
|
||||
self.onlinesimulator.simulate()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
|
||||
print("========== online models ==========")
|
||||
recs_dict = self.onlinesimulator.online_models()
|
||||
for time, recs in recs_dict.items():
|
||||
print(f"{str(time[0])} to {str(time[1])}:")
|
||||
for rec in recs:
|
||||
print(rec.info["id"])
|
||||
|
||||
print("========== online signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
# Run this to run all workflow automaticly
|
||||
# Run this to run all workflow automatically
|
||||
def main(self):
|
||||
self.first_train()
|
||||
self.simulate()
|
||||
print("========== reset ==========")
|
||||
self.rolling_online_manager.reset()
|
||||
print("========== simulate ==========")
|
||||
self.rolling_online_manager.simulate(end_time=self.end_time)
|
||||
print("========== collect results ==========")
|
||||
print(self.rolling_online_manager.get_collector()())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
print("========== online history ==========")
|
||||
print(self.rolling_online_manager.get_online_history(self.exp_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automaticly with your own parameters, use the command below
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
|
||||
fire.Fire(OnlineSimulationExample)
|
||||
|
||||
@@ -1,22 +1,26 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example show how OnlineManager works with rolling tasks.
|
||||
There are two parts including first train and routine.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingAverageStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.online.manager import RollingOnlineManager
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
"""
|
||||
This example show how RollingOnlineManager works with rolling tasks.
|
||||
There are two parts including first train and routine.
|
||||
Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
"""
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
@@ -77,58 +81,75 @@ task_xgboost_config = {
|
||||
class RollingOnlineExample:
|
||||
def __init__(
|
||||
self,
|
||||
exp_name="rolling_exp",
|
||||
task_pool="rolling_task",
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
rolling_step=550,
|
||||
tasks=[task_xgboost_config], # , task_lgb_config],
|
||||
):
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
|
||||
self.rolling_online_manager = RollingOnlineManager(
|
||||
experiment_name=exp_name,
|
||||
rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
trainer=TrainerRM(self.exp_name, self.task_pool),
|
||||
)
|
||||
self.tasks = tasks
|
||||
self.rolling_step = rolling_step
|
||||
strategy = []
|
||||
for task in tasks:
|
||||
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
|
||||
strategy.append(
|
||||
RollingAverageStrategy(
|
||||
name_id,
|
||||
task,
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
TrainerRM(experiment_name=name_id, task_pool=name_id),
|
||||
)
|
||||
)
|
||||
|
||||
_ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine.
|
||||
self.rolling_online_manager = OnlineManager(strategy)
|
||||
self.collector = self.rolling_online_manager.get_collector()
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
print("========== reset ==========")
|
||||
TaskManager(self.task_pool).remove()
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
for task in self.tasks:
|
||||
name_id = task["model"]["class"] + "_" + str(self.rolling_step)
|
||||
TaskManager(name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
if os.path.exists(self._ROLLING_MANAGER_PATH):
|
||||
os.remove(self._ROLLING_MANAGER_PATH)
|
||||
|
||||
for rid in list_recorders(
|
||||
RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False
|
||||
):
|
||||
exp.delete_recorder(rid)
|
||||
for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False):
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
def first_run(self):
|
||||
print("========== reset ==========")
|
||||
self.rolling_online_manager.reset()
|
||||
print("========== first_run ==========")
|
||||
self.reset()
|
||||
self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config])
|
||||
self.rolling_online_manager.first_train()
|
||||
print("========== dump ==========")
|
||||
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
print("========== collect results ==========")
|
||||
print(self.collector())
|
||||
|
||||
def routine(self):
|
||||
print("========== routine ==========")
|
||||
print("========== load ==========")
|
||||
with Path(self._ROLLING_MANAGER_PATH).open("rb") as f:
|
||||
self.rolling_online_manager = pickle.load(f)
|
||||
print("========== routine ==========")
|
||||
self.rolling_online_manager.routine()
|
||||
print(self.rolling_online_manager.collect_artifact())
|
||||
print("========== collect results ==========")
|
||||
print(self.collector())
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def main(self):
|
||||
self.first_run()
|
||||
@@ -137,11 +158,11 @@ class RollingOnlineExample:
|
||||
|
||||
if __name__ == "__main__":
|
||||
####### to train the first version's models, use the command below
|
||||
# python task_manager_rolling_with_updating.py first_run
|
||||
# python rolling_online_management.py first_run
|
||||
|
||||
####### to update the models and predictions after the trading time, use the command below
|
||||
# python task_manager_rolling_with_updating.py after_day
|
||||
# python rolling_online_management.py after_day
|
||||
|
||||
####### to define your own parameters, use `--`
|
||||
# python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
|
||||
fire.Fire(RollingOnlineExample)
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example show how OnlineTool works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained model to `online` model.
|
||||
Next, we will finish updating online prediction.
|
||||
"""
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.manager import OnlineManagerR
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
"""
|
||||
This example show how OnlineManager works when we need update prediction.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model.
|
||||
Next, the RollingOnlineManager will finish updating online prediction
|
||||
"""
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
@@ -65,15 +66,15 @@ class UpdatePredExample:
|
||||
):
|
||||
qlib.init(provider_uri=provider_uri, region=region)
|
||||
self.experiment_name = experiment_name
|
||||
self.online_manager = OnlineManagerR(self.experiment_name)
|
||||
self.online_tool = OnlineToolR(self.experiment_name)
|
||||
self.task_config = task_config
|
||||
|
||||
def first_train(self):
|
||||
rec = task_train(self.task_config, experiment_name=self.experiment_name)
|
||||
self.online_manager.reset_online_tag(rec) # set to online model
|
||||
self.online_tool.reset_online_tag(rec) # set to online model
|
||||
|
||||
def update_online_pred(self):
|
||||
self.online_manager.update_online_pred()
|
||||
self.online_tool.update_online_pred()
|
||||
|
||||
def main(self):
|
||||
self.first_train()
|
||||
|
||||
@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
"""
|
||||
Update the account and strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
|
||||
@@ -128,7 +128,7 @@ class Position:
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
|
||||
@@ -8,6 +8,59 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
"""
|
||||
calculate the precision for long and short operation
|
||||
|
||||
|
||||
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
|
||||
|
||||
.. code-block:: python
|
||||
score
|
||||
datetime instrument
|
||||
2020-12-01 09:30:00 SH600068 0.553634
|
||||
SH600195 0.550017
|
||||
SH600276 0.540321
|
||||
SH600584 0.517297
|
||||
SH600715 0.544674
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
long precision and short precision in time level
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=date_col)
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
l_dom = groupll.apply(lambda x: x > 0)
|
||||
l_c = groupll.count()
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
s_dom = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
|
||||
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
import warnings
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT):
|
||||
"""LightGBM Model for high frequency prediction"""
|
||||
|
||||
def __init__(self, loss="mse", **kwargs):
|
||||
if loss not in {"mse", "binary"}:
|
||||
raise NotImplementedError
|
||||
self.params = {"objective": loss, "verbosity": -1}
|
||||
self.params.update(kwargs)
|
||||
self.model = None
|
||||
|
||||
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
|
||||
"""
|
||||
Calcaute the signal metrics by daily level
|
||||
"""
|
||||
up_pre, down_pre = [], []
|
||||
up_alpha_ll, down_alpha_ll = [], []
|
||||
for date in y_test.index.get_level_values(0).unique():
|
||||
df_res = y_test.loc[date].sort_values("pred")
|
||||
if int(l_cut * len(df_res)) < 10:
|
||||
warnings.warn("Warning: threhold is too low or instruments number is not enough")
|
||||
continue
|
||||
top = df_res.iloc[: int(l_cut * len(df_res))]
|
||||
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
|
||||
|
||||
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
|
||||
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
|
||||
|
||||
down_alpha = top[top.columns[0]].mean()
|
||||
up_alpha = bottom[bottom.columns[0]].mean()
|
||||
|
||||
up_pre.append(up_precision)
|
||||
down_pre.append(down_precision)
|
||||
up_alpha_ll.append(up_alpha)
|
||||
down_alpha_ll.append(down_alpha)
|
||||
|
||||
return (
|
||||
np.array(up_pre).mean(),
|
||||
np.array(down_pre).mean(),
|
||||
np.array(up_alpha_ll).mean(),
|
||||
np.array(down_alpha_ll).mean(),
|
||||
)
|
||||
|
||||
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
|
||||
"""
|
||||
Test the sigal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
# Convert label into alpha
|
||||
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
|
||||
|
||||
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
y_test["pred"] = res
|
||||
|
||||
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
|
||||
print("===============================")
|
||||
print("High frequency signal test")
|
||||
print("===============================")
|
||||
print("Test set precision: ")
|
||||
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
|
||||
print("Test Alpha Average in test set: ")
|
||||
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
|
||||
|
||||
def _prepare_data(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_train["feature"], df_valid["label"]
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
l_name = df_train["label"].columns[0]
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
mapping_fn = lambda x: 0 if x < 0 else 1
|
||||
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
|
||||
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
|
||||
x_train, y_train = df_train["feature"], df_train["label_c"].values
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
num_boost_round=1000,
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
"""
|
||||
finetune model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
dataset : DatasetH
|
||||
dataset for finetuning
|
||||
num_boost_round : int
|
||||
number of round to finetune model
|
||||
verbose_eval : int
|
||||
verbose level
|
||||
"""
|
||||
# Based on existing model and finetune by train more rounds
|
||||
dtrain, _ = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=num_boost_round,
|
||||
init_model=self.model,
|
||||
valid_sets=[dtrain],
|
||||
valid_names=["train"],
|
||||
verbose_eval=verbose_eval,
|
||||
)
|
||||
@@ -214,7 +214,7 @@ def cumulative_return_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
|
||||
|
||||
Graph desc:
|
||||
|
||||
@@ -94,7 +94,7 @@ def rank_label_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
""""""
|
||||
""" """
|
||||
|
||||
_name = None
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
@@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
|
||||
@@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
@@ -1016,7 +1019,8 @@ class ClientProvider(BaseProvider):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
Inst.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
else:
|
||||
|
||||
@@ -27,7 +27,7 @@ class Dataset(Serializable):
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
@@ -92,7 +92,7 @@ class DatasetH(Dataset):
|
||||
handler : Union[dict, DataHandler]
|
||||
handler could be:
|
||||
|
||||
- insntance of `DataHandler`
|
||||
- instance of `DataHandler`
|
||||
|
||||
- config of `DataHandler`. Please refer to `DataHandler`
|
||||
|
||||
@@ -124,7 +124,7 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
Config of DataHandler, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
@@ -148,11 +148,11 @@ class DatasetH(Dataset):
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHanlder, which could include the following arguments:
|
||||
init arguments of DataHandler, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : wheter to enable cache
|
||||
- enable_cache : whether to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
@@ -238,7 +238,7 @@ class TSDataSampler:
|
||||
(T)ime-(S)eries DataSampler
|
||||
This is the result of TSDatasetH
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
@@ -310,7 +310,7 @@ class TSDataSampler:
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
|
||||
@@ -472,7 +472,7 @@ class TSDatasetH(DatasetH):
|
||||
(T)ime-(S)eries Dataset (H)andler
|
||||
|
||||
|
||||
Covnert the tabular data to Time-Series data
|
||||
Convert the tabular data to Time-Series data
|
||||
|
||||
Requirements analysis
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
Any order of the index level can be supported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -77,7 +77,7 @@ class DataHandler(Serializable):
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor.
|
||||
initialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
@@ -128,7 +128,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
Set Up the data in case of running intialization for multiple time
|
||||
Set Up the data in case of running initialization for multiple time
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -453,7 +453,7 @@ class DataHandlerLP(DataHandler):
|
||||
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Set up the data in case of running intialization for multiple time
|
||||
Set up the data in case of running initialization for multiple time
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -130,7 +130,7 @@ class FilterCol(Processor):
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
def __call__(self, df):
|
||||
def tanh_denoise(data):
|
||||
@@ -145,7 +145,7 @@ class TanhProcess(Processor):
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
"""Process infinity """
|
||||
"""Process infinity"""
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
|
||||
38
qlib/log.py
38
qlib/log.py
@@ -12,7 +12,41 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
class MetaLogger(type):
|
||||
def __new__(cls, name, bases, dict):
|
||||
wrapper_dict = logging.Logger.__dict__.copy()
|
||||
for key in wrapper_dict:
|
||||
if key not in dict and key != "__reduce__":
|
||||
dict[key] = wrapper_dict[key]
|
||||
return type.__new__(cls, name, bases, dict)
|
||||
|
||||
|
||||
class QlibLogger(metaclass=MetaLogger):
|
||||
"""
|
||||
Customized logger for Qlib.
|
||||
"""
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
if name in {"__setstate__"}:
|
||||
raise AttributeError
|
||||
return self.logger.__getattribute__(name)
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
@@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
# Get logger.
|
||||
module_logger = logging.getLogger(module_name)
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
return module_logger
|
||||
|
||||
|
||||
@@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
"""Make predictions after modeling things"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" leverage Python syntactic sugar to make the models' behaviors like functions """
|
||||
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -1,36 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.workflow.task.collect import Collector
|
||||
from qlib.utils.serial import Serializable
|
||||
|
||||
|
||||
def ens_workflow(collector: Collector, process_list, *args, **kwargs):
|
||||
"""the ensemble workflow based on collector and different dict processors.
|
||||
|
||||
Args:
|
||||
collector (Collector): the collector to collect the result into {result_key: things}
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
The processor order is same as the list order.
|
||||
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
|
||||
Returns:
|
||||
dict: the ensemble dict
|
||||
"""
|
||||
collect_dict = collector.collect()
|
||||
if not isinstance(process_list, list):
|
||||
process_list = [process_list]
|
||||
|
||||
ensemble = {}
|
||||
for artifact in collect_dict:
|
||||
value = collect_dict[artifact]
|
||||
for process in process_list:
|
||||
if not callable(process):
|
||||
raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.")
|
||||
value = process(value, *args, **kwargs)
|
||||
ensemble[artifact] = value
|
||||
|
||||
return ensemble
|
||||
|
||||
|
||||
class Ensemble:
|
||||
@@ -49,21 +25,45 @@ class Ensemble:
|
||||
raise NotImplementedError(f"Please implement the `__call__` method.")
|
||||
|
||||
|
||||
class SingleKeyEnsemble(Ensemble):
|
||||
|
||||
"""
|
||||
Extract the object if there is only one key and value in dict. Make result more readable.
|
||||
{Only key: Only value} -> Only value
|
||||
If there are more than 1 key or less than 1 key, then do nothing.
|
||||
Even you can run this recursively to make dict more readable.
|
||||
NOTE: Default run recursively.
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
|
||||
if not isinstance(ensemble_dict, dict):
|
||||
return ensemble_dict
|
||||
if recursion:
|
||||
tmp_dict = {}
|
||||
for k, v in ensemble_dict.items():
|
||||
tmp_dict[k] = self(v, recursion)
|
||||
ensemble_dict = tmp_dict
|
||||
keys = list(ensemble_dict.keys())
|
||||
if len(keys) == 1:
|
||||
ensemble_dict = ensemble_dict[keys[0]]
|
||||
return ensemble_dict
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
|
||||
|
||||
"""Merge the rolling objects in an Ensemble"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict):
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.Dataframe, and have the index "datetime"
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}.
|
||||
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.Dataframe: the complete result of rolling.
|
||||
pd.DataFrame: the complete result of rolling.
|
||||
"""
|
||||
artifact_list = list(ensemble_dict.values())
|
||||
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
|
||||
@@ -72,3 +72,24 @@ class RollingEnsemble(Ensemble):
|
||||
artifact = artifact[~artifact.index.duplicated(keep="last")]
|
||||
artifact = artifact.sort_index()
|
||||
return artifact
|
||||
|
||||
|
||||
class AverageEnsemble(Ensemble):
|
||||
def __call__(self, ensemble_dict: dict):
|
||||
"""
|
||||
Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
|
||||
|
||||
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime"
|
||||
|
||||
Args:
|
||||
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
|
||||
The key of the dict will be ignored.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: the complete result of averaging.
|
||||
"""
|
||||
values = list(ensemble_dict.values())
|
||||
results = pd.concat(values, axis=1)
|
||||
results = results.mean(axis=1).to_frame("score")
|
||||
results = results.sort_index()
|
||||
return results
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Group can group a set of object based on `group_func` and change them to a dict.
|
||||
After group, we provide a method to reduce them.
|
||||
|
||||
For example:
|
||||
|
||||
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
|
||||
|
||||
"""
|
||||
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
from joblib import Parallel, delayed
|
||||
@@ -21,20 +35,20 @@ class Group:
|
||||
self._group_func = group_func
|
||||
self._ens_func = ens
|
||||
|
||||
def group(self, *args, **kwargs):
|
||||
def group(self, *args, **kwargs) -> dict:
|
||||
# TODO: such design is weird when `_group_func` is the only configurable part in the class
|
||||
if isinstance(getattr(self, "_group_func", None), Callable):
|
||||
return self._group_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `group_func`.")
|
||||
|
||||
def reduce(self, *args, **kwargs):
|
||||
def reduce(self, *args, **kwargs) -> dict:
|
||||
if isinstance(getattr(self, "_ens_func", None), Callable):
|
||||
return self._ens_func(*args, **kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"Please specify valid `_ens_func`.")
|
||||
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs):
|
||||
def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict:
|
||||
"""Group the ungrouped_dict into different groups.
|
||||
|
||||
Args:
|
||||
@@ -59,7 +73,7 @@ class Group:
|
||||
class RollingGroup(Group):
|
||||
"""group the rolling dict"""
|
||||
|
||||
def group(self, rolling_dict: dict):
|
||||
def group(self, rolling_dict: dict) -> dict:
|
||||
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
|
||||
|
||||
NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly.
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
import abc
|
||||
import typing
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> typing.List[dict]:
|
||||
"""
|
||||
generate
|
||||
|
||||
Parameters
|
||||
----------
|
||||
args, kwargs:
|
||||
The info for generating tasks
|
||||
Example 1):
|
||||
input: a specific task template
|
||||
output: rolling version of the tasks
|
||||
Example 2):
|
||||
input: a specific task template
|
||||
output: a set of tasks with different losses
|
||||
|
||||
Returns
|
||||
-------
|
||||
typing.List[dict]:
|
||||
A list of tasks
|
||||
"""
|
||||
pass
|
||||
@@ -1,58 +1,69 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import copy
|
||||
"""
|
||||
The Trainer will train a list of tasks and return a list of model recorder.
|
||||
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
|
||||
|
||||
This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from xxlimited import Str
|
||||
from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import Model
|
||||
import socket
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder:
|
||||
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
|
||||
"""
|
||||
Begin a task training with starting a recorder and saving the task config.
|
||||
Begin a task training to start a recorder and save the task config.
|
||||
|
||||
Args:
|
||||
task_config (dict)
|
||||
experiment_name (str)
|
||||
task_config (dict): the config of a task
|
||||
experiment_name (str): the name of experiment
|
||||
recorder_name (str): the given name will be the recorder name. None for using rid.
|
||||
|
||||
Returns:
|
||||
Recorder
|
||||
Recorder: the model recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_name=str(time.time())):
|
||||
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
R.save_objects(**{"task": task_config}) # keep the original format and datatype
|
||||
R.set_tags(**{"hostname": socket.gethostname(), "train_status": "begin_task_train"})
|
||||
R.set_tags(**{"hostname": socket.gethostname()})
|
||||
recorder: Recorder = R.get_recorder()
|
||||
return recorder
|
||||
|
||||
|
||||
def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
|
||||
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
Finished task training with real model fitting and saving.
|
||||
Finish task training with real model fitting and saving.
|
||||
|
||||
Args:
|
||||
rec (Recorder): This recorder will be resumed
|
||||
experiment_name (str)
|
||||
rec (Recorder): the recorder will be resumed
|
||||
experiment_name (str): the name of experiment
|
||||
|
||||
Returns:
|
||||
Recorder
|
||||
Recorder: the model recorder
|
||||
"""
|
||||
with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True):
|
||||
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
|
||||
task_config = R.load_object("task")
|
||||
# model & dataset initiaiton
|
||||
# model & dataset initiation
|
||||
model: Model = init_instance_by_config(task_config["model"])
|
||||
dataset: Dataset = init_instance_by_config(task_config["dataset"])
|
||||
# model training
|
||||
model.fit(dataset)
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
# This dataset is saved for online inference. So the concrete data should not be dumped
|
||||
# this dataset is saved for online inference. So the concrete data should not be dumped
|
||||
dataset.config(dump_all=False, recursive=True)
|
||||
R.save_objects(**{"dataset": dataset})
|
||||
# generate records: prediction, backtest, and analysis
|
||||
@@ -67,18 +78,18 @@ def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs):
|
||||
rconf = {"recorder": rec}
|
||||
r = cls(**kwargs, **rconf)
|
||||
r.generate()
|
||||
R.set_tags(**{"train_status": "end_task_train"})
|
||||
|
||||
return rec
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
|
||||
"""
|
||||
task based training
|
||||
Task based training, will be divided into two steps.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_config : dict
|
||||
A dict describes a task setting.
|
||||
The config of a task.
|
||||
experiment_name: str
|
||||
The name of experiment
|
||||
|
||||
@@ -96,39 +107,79 @@ class Trainer:
|
||||
The trainer which can train a list of model
|
||||
"""
|
||||
|
||||
def train(self, tasks: list, *args, **kwargs):
|
||||
"""Given a list of model definition, begin a training and return the models.
|
||||
def __init__(self):
|
||||
self.delay = False
|
||||
|
||||
def train(self, tasks: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of model definition, begin a training and return the models.
|
||||
|
||||
Args:
|
||||
tasks: a list of tasks
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `train` method.")
|
||||
|
||||
def end_train(self, models, *args, **kwargs):
|
||||
"""Given a list of models, finished something in the end of training if you need.
|
||||
def end_train(self, models: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of models, finished something in the end of training if you need.
|
||||
The models maybe Recorder, txt file, database and so on.
|
||||
|
||||
Args:
|
||||
models: a list of models
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
# do nothing if you finished all work in `train` method
|
||||
return models
|
||||
|
||||
def is_delay(self) -> bool:
|
||||
"""
|
||||
If Trainer will delay finishing `end_train`.
|
||||
|
||||
Returns:
|
||||
bool: if DelayTrainer
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Reset the Trainer status.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""Trainer based on (R)ecorder.
|
||||
"""
|
||||
Trainer based on (R)ecorder.
|
||||
It will train a list of tasks and return a list of model recorder in a linear way.
|
||||
|
||||
Assumption: models were defined by `task` and the results will saved to `Recorder`
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name, train_func=task_train):
|
||||
def __init__(self, experiment_name: str, train_func: Callable = task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
|
||||
def train(self, tasks: list, train_func=None, *args, **kwargs):
|
||||
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
@@ -137,17 +188,74 @@ class TrainerR(Trainer):
|
||||
train_func = self.train_func
|
||||
recs = []
|
||||
for task in tasks:
|
||||
recs.append(train_func(task, self.experiment_name, *args, **kwargs))
|
||||
rec = train_func(task, self.experiment_name, **kwargs)
|
||||
rec.set_tags(**{"train_status": "begin_task_train"})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> list:
|
||||
for rec in recs:
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainerR(TrainerR):
|
||||
"""
|
||||
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
for rec in recs:
|
||||
end_train_func(rec, **kwargs)
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
return recs
|
||||
|
||||
|
||||
class TrainerRM(Trainer):
|
||||
"""Trainer based on (R)ecorder and Task(M)anager
|
||||
"""
|
||||
Trainer based on (R)ecorder and Task(M)anager.
|
||||
It can train a list of tasks and return a list of model recorder in a multiprocessing way.
|
||||
|
||||
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, task_pool: str, train_func=task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the name of experiment.
|
||||
task_pool (str): task pool name in TaskManager.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
@@ -155,20 +263,23 @@ class TrainerRM(Trainer):
|
||||
def train(
|
||||
self,
|
||||
tasks: list,
|
||||
train_func=None,
|
||||
before_status=TaskManager.STATUS_WAITING,
|
||||
after_status=TaskManager.STATUS_DONE,
|
||||
*args,
|
||||
train_func: Callable = None,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
):
|
||||
"""Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
|
||||
|
||||
This method defaults to a single process, but TaskManager offered a great way to parallel training.
|
||||
Users can customize their train_func to realize multiple processes or even multiple machines.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method.
|
||||
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
@@ -183,63 +294,29 @@ class TrainerRM(Trainer):
|
||||
experiment_name=self.experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
recs.append(tm.re_query(_id)["res"])
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{"train_status": "begin_task_train"})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainerR(TrainerR):
|
||||
"""
|
||||
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
super().__init__(experiment_name, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.recs = []
|
||||
|
||||
def train(self, tasks: list, train_func, *args, **kwargs):
|
||||
"""
|
||||
Same as `train` of TrainerR, the results will be recorded in self.recs
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
self.recs = super().train(tasks, train_func=train_func, *args, **kwargs)
|
||||
return self.recs
|
||||
|
||||
def end_train(self, recs=None, end_train_func=None):
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finished real data loading and model fitting.
|
||||
|
||||
Args:
|
||||
recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
if recs is None:
|
||||
recs = copy.deepcopy(self.recs)
|
||||
# the models will be only trained once
|
||||
self.recs = []
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
def end_train(self, recs: list, **kwargs) -> list:
|
||||
for rec in recs:
|
||||
end_train_func(rec)
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
return recs
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
.. note::
|
||||
this method will delete all task in this task_pool!
|
||||
"""
|
||||
tm = TaskManager(task_pool=self.task_pool)
|
||||
tm.remove()
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -250,28 +327,28 @@ class DelayTrainerRM(TrainerRM):
|
||||
def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train):
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def train(self, tasks: list, train_func=None, *args, **kwargs):
|
||||
def train(self, tasks: list, train_func=None, **kwargs):
|
||||
"""
|
||||
Same as `train` of TrainerRM, the results will be recorded in self.recs
|
||||
|
||||
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default.
|
||||
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
"""
|
||||
return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, *args, **kwargs)
|
||||
return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs)
|
||||
|
||||
def end_train(self, recs, end_train_func=None):
|
||||
def end_train(self, recs, end_train_func=None, **kwargs):
|
||||
"""
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finished real data loading and model fitting.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
Args:
|
||||
recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs..
|
||||
end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
list: a list of Recorders
|
||||
@@ -284,5 +361,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
self.task_pool,
|
||||
experiment_name=self.experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{"train_status": "end_task_train"})
|
||||
return recs
|
||||
|
||||
@@ -5,9 +5,9 @@ import abc
|
||||
|
||||
|
||||
class BaseOptimizer(abc.ABC):
|
||||
""" Construct portfolio with a optimization related method """
|
||||
"""Construct portfolio with a optimization related method"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" Generate a optimized portfolio allocation """
|
||||
"""Generate a optimized portfolio allocation"""
|
||||
pass
|
||||
|
||||
@@ -3,11 +3,12 @@
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Serializable:
|
||||
"""
|
||||
Serializable will change the behaviours of pickle.
|
||||
Serializable will change the behaviors of pickle.
|
||||
- It only saves the state whose name **does not** start with `_`
|
||||
It provides a syntactic sugar for distinguish the attributes which user doesn't want.
|
||||
- For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk
|
||||
@@ -70,7 +71,7 @@ class Serializable:
|
||||
obj.config(**params, recursive=True)
|
||||
del self.__dict__[self.FLAG_KEY]
|
||||
|
||||
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
|
||||
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
|
||||
@@ -23,7 +23,10 @@ class QlibRecorder:
|
||||
@contextmanager
|
||||
def start(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
name of the experiment one wants to start.
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -57,7 +64,14 @@ class QlibRecorder:
|
||||
resume : bool
|
||||
whether to resume the specific recorder with given name under the given experiment.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
run = self.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -65,7 +79,9 @@ class QlibRecorder:
|
||||
raise e
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
|
||||
def start_exp(
|
||||
self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False
|
||||
):
|
||||
"""
|
||||
Lower level method for starting an experiment. When use this method, one should end the experiment manually
|
||||
and the status of the recorder may not be handled properly. Here is the example code:
|
||||
@@ -79,8 +95,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
the name of the experiment to be started
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -93,7 +113,14 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance being started.
|
||||
"""
|
||||
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
return self.exp_manager.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
|
||||
def end_exp(self, recorder_status=Recorder.STATUS_FI):
|
||||
"""
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
from pathlib import Path
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Experiment:
|
||||
@@ -39,12 +39,14 @@ class Experiment:
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
"""
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder_id : str
|
||||
the id of the recorder to be created.
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
resume : bool
|
||||
@@ -238,14 +240,14 @@ class MLflowExperiment(Experiment):
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# Get or create recorder
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
# resume the recorder
|
||||
if resume:
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
# create a new recorder
|
||||
else:
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import mlflow
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
import os, logging
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
@@ -14,7 +14,7 @@ from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class ExpManager:
|
||||
@@ -33,7 +33,10 @@ class ExpManager:
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class ExpManager:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the active experiment.
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
recorder_id : str
|
||||
id of the recorder to be started.
|
||||
recorder_name : str
|
||||
name of the recorder to be started.
|
||||
uri : str
|
||||
@@ -298,7 +305,10 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -308,11 +318,11 @@ class MLflowExpManager(ExpManager):
|
||||
# Create experiment
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name, resume)
|
||||
self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
|
||||
@@ -1,477 +1,177 @@
|
||||
from copy import deepcopy
|
||||
from operator import index
|
||||
import pandas as pd
|
||||
from qlib.model.ens.ensemble import ens_workflow
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.utils.serial import Serializable
|
||||
from typing import Dict, List, Union
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.trainer import Trainer, TrainerR, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This class is a component of online serving, it can manage a series of models dynamically.
|
||||
With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models.
|
||||
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models.
|
||||
In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated.
|
||||
So this module provide a series methods to control this process.
|
||||
|
||||
This module also provide a method to simulate `Online Strategy <#Online Strategy>`_ in the history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow.online.strategy import OnlineStrategy
|
||||
from qlib.workflow.task.collect import HyperCollector
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
# NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models.
|
||||
NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment.
|
||||
|
||||
def __init__(self, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
init OnlineManager.
|
||||
|
||||
Args:
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.cur_time = None
|
||||
|
||||
def prepare_signals(self):
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
Must use `pass` even though there is nothing to do.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def get_signals(self):
|
||||
"""
|
||||
After preparing signals, here is the method to get them.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
return the new tasks waiting for training.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `tag`.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
tag (str):
|
||||
`ONLINE_TAG` for first train or additional train
|
||||
`NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag`
|
||||
`OFFLINE_TAG` for train but offline those models
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
"""
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
if len(tasks) > 0:
|
||||
if self.trainer is not None:
|
||||
new_models = self.trainer.train(tasks, *args, **kwargs)
|
||||
if check_func(new_models):
|
||||
self.set_online_tag(tag, new_models)
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.")
|
||||
else:
|
||||
self.logger.warn("No trainer to train new tasks.")
|
||||
|
||||
def update_online_pred(self):
|
||||
"""
|
||||
After the end of a routine, update the predictions of online models to latest.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
def set_online_tag(self, tag, recorder):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self):
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, recorders=None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (List, optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
list: new online recorder. [] if there is no update.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return online models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def first_train(self):
|
||||
"""
|
||||
Train a series of models firstly and set some of them into online models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `first_train` method.")
|
||||
|
||||
def get_collector(self):
|
||||
"""
|
||||
Return the collector.
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def delay_prepare(self, rec_dict, *args, **kwargs):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}.
|
||||
*args, **kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for time_segment, recs_list in rec_dict.items():
|
||||
self.trainer.end_train(recs_list, *args, **kwargs)
|
||||
self.reset_online_tag(recs_list)
|
||||
self.prepare_signals()
|
||||
signal_max = self.get_signals().index.get_level_values("datetime").max()
|
||||
if time_segment[1] is not None and signal_max > time_segment[1]:
|
||||
raise ValueError(
|
||||
f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}"
|
||||
)
|
||||
|
||||
def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs):
|
||||
"""
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models
|
||||
|
||||
NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions.
|
||||
|
||||
Args:
|
||||
cur_time ([type], optional): [description]. Defaults to None.
|
||||
delay_prepare (bool, optional): [description]. Defaults to False.
|
||||
*args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config.
|
||||
|
||||
Returns:
|
||||
[type]: [description]
|
||||
"""
|
||||
self.cur_time = cur_time # None for latest date
|
||||
if not delay_prepare:
|
||||
self.update_online_pred()
|
||||
self.prepare_signals()
|
||||
tasks = self.prepare_tasks(*args, **kwargs)
|
||||
self.prepare_new_models(tasks, *args, **kwargs)
|
||||
|
||||
return self.reset_online_tag()
|
||||
|
||||
|
||||
class OnlineManagerR(OnlineManager):
|
||||
"""
|
||||
The implementation of OnlineManager based on (R)ecorder.
|
||||
|
||||
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
|
||||
It also provide a history recording which models are onlined at what time.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
init OnlineManagerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(trainer=trainer, need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
self.signal_rec = None
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[Recorder, List])
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
if self.need_log:
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder):
|
||||
"""
|
||||
Given a model and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Recorder): a instance of recorder
|
||||
|
||||
Returns:
|
||||
str: the tag
|
||||
"""
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List] = None):
|
||||
"""offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing.
|
||||
|
||||
Args:
|
||||
recorders (Union[Recorder, List], optional):
|
||||
the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model.
|
||||
|
||||
Returns:
|
||||
list: new online recorder. [] if there is no update.
|
||||
"""
|
||||
if recorder is None:
|
||||
recorder = list(
|
||||
list_recorders(
|
||||
self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG
|
||||
).values()
|
||||
)
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
if len(recorder) == 0:
|
||||
if self.need_log:
|
||||
self.logger.info("No 'next online' model, just use current 'online' models.")
|
||||
return []
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(OnlineManager.ONLINE_TAG, recorder)
|
||||
return recorder
|
||||
|
||||
def get_signals(self):
|
||||
"""
|
||||
get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
|
||||
|
||||
Returns:
|
||||
signals
|
||||
"""
|
||||
if self.signal_rec is None:
|
||||
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder()
|
||||
signals = None
|
||||
try:
|
||||
signals = self.signal_rec.load_object("signals")
|
||||
except OSError:
|
||||
self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?")
|
||||
return signals
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return online models.
|
||||
|
||||
Returns:
|
||||
list: the list of online models
|
||||
"""
|
||||
return list(
|
||||
list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG).values()
|
||||
)
|
||||
|
||||
def update_online_pred(self):
|
||||
"""
|
||||
Update all online model predictions to the latest day in Calendar
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update()
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
|
||||
def prepare_signals(self, over_write=False):
|
||||
"""
|
||||
Average the predictions of online models and offer a trading signals every routine.
|
||||
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
NOTE: Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
Args:
|
||||
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
"""
|
||||
if self.signal_rec is None:
|
||||
with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder()
|
||||
|
||||
pred = []
|
||||
try:
|
||||
old_signals = self.signal_rec.load_object("signals")
|
||||
except OSError:
|
||||
old_signals = None
|
||||
|
||||
for rec in self.online_models():
|
||||
pred.append(rec.load_object("pred.pkl"))
|
||||
|
||||
signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score")
|
||||
signals = signals.sort_index()
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.")
|
||||
self.signal_rec.save_objects(**{"signals": signals})
|
||||
|
||||
|
||||
class RollingOnlineManager(OnlineManagerR):
|
||||
"""An implementation of OnlineManager based on Rolling."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str,
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
strategy: Union[OnlineStrategy, List[OnlineStrategy]],
|
||||
begin_time: Union[str, pd.Timestamp] = None,
|
||||
freq="day",
|
||||
need_log=True,
|
||||
):
|
||||
"""
|
||||
init RollingOnlineManager.
|
||||
Init OnlineManager.
|
||||
One OnlineManager must have at least one OnlineStrategy.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
rolling_gen (RollingGen): a instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
collector (Collector, optional): a instance of Collector. Defaults to None.
|
||||
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
|
||||
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date.
|
||||
freq (str, optional): data frequency. Defaults to "day".
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
if trainer is None:
|
||||
trainer = TrainerR(experiment_name)
|
||||
super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log)
|
||||
self.ta = TimeAdjuster()
|
||||
self.rg = rolling_gen
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
if not isinstance(strategy, list):
|
||||
strategy = [strategy]
|
||||
self.strategy = strategy
|
||||
self.freq = freq
|
||||
if begin_time is None:
|
||||
begin_time = D.calendar(freq=self.freq).max()
|
||||
self.begin_time = pd.Timestamp(begin_time)
|
||||
self.cur_time = self.begin_time
|
||||
self.history = {}
|
||||
|
||||
def get_collector(self, rec_key_func=None, rec_filter_func=None):
|
||||
def first_train(self):
|
||||
"""
|
||||
Get the instance of collector to collect results. The returned collector must can distinguish results in different models.
|
||||
Assumption: the models can be distinguished based on model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your own method or use another rec_key_func.
|
||||
Run every strategy first_train method and record the online history.
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
online_models = strategy.first_train()
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
|
||||
def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}):
|
||||
"""
|
||||
Run typical update process for every strategy and record the online history.
|
||||
|
||||
The typical update process after a routine, such as day by day or month by month.
|
||||
The process is: Prepare signals -> Prepare tasks -> Prepare online models.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
|
||||
task_kwargs (dict): the params for `prepare_tasks`
|
||||
model_kwargs (dict): the params for `prepare_online_models`
|
||||
"""
|
||||
if cur_time is None:
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
for strategy in self.strategy:
|
||||
if self.need_log:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if not strategy.trainer.is_delay():
|
||||
strategy.prepare_signals()
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
online_models = strategy.prepare_online_models(tasks, **model_kwargs)
|
||||
if len(online_models) > 0:
|
||||
self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func)
|
||||
|
||||
def collect_artifact(self, rec_key_func=None, rec_filter_func=None):
|
||||
def get_collector(self) -> HyperCollector:
|
||||
"""
|
||||
collecting artifact based on the collector and RollingGroup.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
|
||||
Returns:
|
||||
dict: the artifact dict after rolling ensemble
|
||||
HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable).
|
||||
"""
|
||||
artifact = ens_workflow(
|
||||
self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func), RollingGroup()
|
||||
)
|
||||
return artifact
|
||||
collector_dict = {}
|
||||
for strategy in self.strategy:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
return HyperCollector(collector_dict, process_list=SingleKeyEnsemble())
|
||||
|
||||
def first_train(self, task_configs: list):
|
||||
def get_online_history(self, strategy_name_id: str) -> list:
|
||||
"""
|
||||
Use rolling_gen to generate different tasks based on task_configs and trained them.
|
||||
Get the online history based on strategy_name_id.
|
||||
|
||||
Args:
|
||||
task_configs (list or dict): a list of task configs or a task config
|
||||
strategy_name_id (str): the name_id of strategy
|
||||
|
||||
Returns:
|
||||
Collector: a instance of a Collector.
|
||||
list: a list like [(begin_time, [online_models])]
|
||||
"""
|
||||
tasks = task_generator(
|
||||
tasks=task_configs,
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
self.prepare_new_models(tasks, tag=self.ONLINE_TAG)
|
||||
history_dict = self.history[strategy_name_id]
|
||||
history = []
|
||||
for time in sorted(history_dict):
|
||||
models = history_dict[time]
|
||||
history.append((time, models))
|
||||
return history
|
||||
|
||||
def delay_prepare(self, delay_kwargs={}):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
|
||||
Args:
|
||||
delay_kwargs: the params for `delay_prepare`
|
||||
"""
|
||||
for strategy in self.strategy:
|
||||
strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs)
|
||||
|
||||
def get_signals(self) -> pd.DataFrame:
|
||||
"""
|
||||
Average all strategy signals as the online signals.
|
||||
|
||||
Assumption: the signals from every strategy is pd.DataFrame. Override this function to change.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame: signals
|
||||
"""
|
||||
signals_dict = {}
|
||||
for strategy in self.strategy:
|
||||
signals_dict[strategy.name_id] = strategy.get_signals()
|
||||
return AverageEnsemble()(signals_dict)
|
||||
|
||||
def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector:
|
||||
"""
|
||||
Starting from current time, this method will simulate every routine in OnlineManager until end time.
|
||||
|
||||
Considering the parallel training, the models and signals can be perpared after all routine simulating.
|
||||
|
||||
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
|
||||
|
||||
Returns:
|
||||
HyperCollector: the OnlineManager's collector
|
||||
"""
|
||||
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
|
||||
self.first_train()
|
||||
for cur_time in cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs)
|
||||
self.delay_prepare(delay_kwargs=delay_kwargs)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.get_collector()
|
||||
|
||||
def prepare_tasks(self):
|
||||
def reset(self):
|
||||
"""
|
||||
Prepare new tasks based on new date.
|
||||
This method will reset all strategy!
|
||||
|
||||
Returns:
|
||||
list: a list of new tasks.
|
||||
**Be careful to use it.**
|
||||
"""
|
||||
latest_records, max_test = self.list_latest_recorders(
|
||||
lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG
|
||||
)
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rid, rec in latest_records.items():
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
|
||||
def list_latest_recorders(self, rec_filter_func=None):
|
||||
"""find latest recorders based on test segments.
|
||||
|
||||
Args:
|
||||
rec_filter_func (Callable, optional): recorder filter. Defaults to None.
|
||||
|
||||
Returns:
|
||||
dict, tuple: the latest recorders and the latest date of them
|
||||
"""
|
||||
recs_flt = list_recorders(self.exp_name, rec_filter_func)
|
||||
if len(recs_flt) == 0:
|
||||
return recs_flt, None
|
||||
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values())
|
||||
latest_rec = {}
|
||||
for rid, rec in recs_flt.items():
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec[rid] = rec
|
||||
return latest_rec, max_test
|
||||
self.cur_time = self.begin_time
|
||||
self.history = {}
|
||||
for strategy in self.strategy:
|
||||
strategy.reset()
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
from qlib.data import D
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
|
||||
|
||||
class OnlineSimulator:
|
||||
"""
|
||||
To simulate online serving in the past, like a "online serving backtest".
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
start_time,
|
||||
end_time,
|
||||
online_manager: OnlineManager,
|
||||
frequency="day",
|
||||
):
|
||||
"""
|
||||
init OnlineSimulator.
|
||||
|
||||
Args:
|
||||
start_time (str or pd.Timestamp): the start time of simulating.
|
||||
end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest.
|
||||
onlinemanager (OnlineManager): the instance of OnlineManager
|
||||
frequency (str, optional): the data frequency. Defaults to "day".
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency)
|
||||
self.start_time = self.cal[0]
|
||||
self.end_time = self.cal[-1]
|
||||
self.olm = online_manager
|
||||
if len(self.cal) == 0:
|
||||
self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.")
|
||||
|
||||
def simulate(self, *args, **kwargs):
|
||||
"""
|
||||
Starting from start time, this method will simulate every routine in OnlineManager.
|
||||
NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating.
|
||||
|
||||
Returns:
|
||||
Collector: the OnlineManager's collector
|
||||
"""
|
||||
self.rec_dict = {}
|
||||
tmp_begin = self.start_time
|
||||
tmp_end = None
|
||||
prev_recorders = self.olm.online_models()
|
||||
for cur_time in self.cal:
|
||||
self.logger.info(f"Simulating at {str(cur_time)}......")
|
||||
recorders = self.olm.routine(cur_time, True, *args, **kwargs)
|
||||
if len(recorders) == 0:
|
||||
tmp_end = cur_time
|
||||
else:
|
||||
self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders
|
||||
tmp_begin = cur_time
|
||||
prev_recorders = recorders
|
||||
self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders
|
||||
# finished perparing models (and pred) and signals
|
||||
self.olm.delay_prepare(self.rec_dict)
|
||||
self.logger.info(f"Finished preparing signals")
|
||||
return self.olm.get_collector()
|
||||
|
||||
def online_models(self):
|
||||
"""
|
||||
Return a online models dict likes {(begin_time, end_time):[online models]}.
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
if hasattr(self, "rec_dict"):
|
||||
return self.rec_dict
|
||||
self.logger.warn(f"Please call `simulate` firstly when calling `online_models`")
|
||||
return {}
|
||||
339
qlib/workflow/online/strategy.py
Normal file
339
qlib/workflow/online/strategy.py
Normal file
@@ -0,0 +1,339 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineStrategy is a set of strategy for online serving.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import Trainer, TrainerR
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster, list_recorders
|
||||
|
||||
|
||||
class OnlineStrategy:
|
||||
"""
|
||||
OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared.
|
||||
"""
|
||||
|
||||
def __init__(self, name_id: str, trainer: Trainer = None, need_log=True):
|
||||
"""
|
||||
Init OnlineStrategy.
|
||||
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.name_id = name_id
|
||||
self.trainer = trainer
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
self.tool = OnlineTool()
|
||||
|
||||
def prepare_signals(self, delay: bool = False):
|
||||
"""
|
||||
After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine.
|
||||
|
||||
NOTE: Given a set prediction, all signals before these prediction end time will be prepared well.
|
||||
|
||||
Args:
|
||||
delay: bool
|
||||
If this method was called by `delay_prepare`
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_signals` method.")
|
||||
|
||||
def prepare_tasks(self, *args, **kwargs):
|
||||
"""
|
||||
After the end of a routine, check whether we need to prepare and train some new tasks.
|
||||
Return the new tasks waiting for training.
|
||||
|
||||
You can find last online models by OnlineTool.online_models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
|
||||
|
||||
def prepare_online_models(self, tasks, check_func=None, **kwargs):
|
||||
"""
|
||||
Use trainer to train a list of tasks and set the trained model to `online`.
|
||||
|
||||
NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of tasks.
|
||||
check_func: the method to judge if a model can be online.
|
||||
The parameter is the model record and return True for online.
|
||||
None for online every models.
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
|
||||
"""
|
||||
if check_func is None:
|
||||
check_func = lambda x: True
|
||||
online_models = []
|
||||
if len(tasks) > 0:
|
||||
new_models = self.trainer.train(tasks, **kwargs)
|
||||
for model in new_models:
|
||||
if check_func(model):
|
||||
online_models.append(model)
|
||||
self.tool.reset_online_tag(online_models)
|
||||
return online_models
|
||||
|
||||
def first_train(self):
|
||||
"""
|
||||
Train a series of models firstly and set some of them as online models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `first_train` method.")
|
||||
|
||||
def get_collector(self) -> Collector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving.
|
||||
|
||||
|
||||
For example:
|
||||
1) collect predictions in Recorder
|
||||
2) collect signals in .txt file
|
||||
|
||||
Returns:
|
||||
Collector
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_collector` method.")
|
||||
|
||||
def delay_prepare(self, history: list, **kwargs):
|
||||
"""
|
||||
Prepare all models and signals if there are something waiting for prepare.
|
||||
|
||||
Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way.
|
||||
|
||||
Args:
|
||||
history (list): an online models list likes [begin_time:[online models]].
|
||||
**kwargs: will be passed to end_train which means will be passed to customized train method.
|
||||
"""
|
||||
for begin_time, recs_list in history:
|
||||
self.trainer.end_train(recs_list, **kwargs)
|
||||
self.tool.reset_online_tag(recs_list)
|
||||
self.prepare_signals(delay=True)
|
||||
|
||||
def get_signals(self):
|
||||
"""
|
||||
Get prepared signals.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_signals` method.")
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class RollingAverageStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name_id: str,
|
||||
task_template: Union[dict, List[dict]],
|
||||
rolling_gen: RollingGen,
|
||||
trainer: Trainer = None,
|
||||
need_log=True,
|
||||
signal_exp_name="OnlineManagerSignals",
|
||||
):
|
||||
"""
|
||||
Init RollingAverageStrategy.
|
||||
|
||||
Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one.
|
||||
|
||||
Args:
|
||||
name_id (str): a unique name or id. Will be also the name of Experiment.
|
||||
task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
|
||||
rolling_gen (RollingGen): an instance of RollingGen
|
||||
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
signal_exp_path (str): a specific experiment to save signals of different experiment.
|
||||
"""
|
||||
super().__init__(name_id=name_id, trainer=trainer, need_log=need_log)
|
||||
self.exp_name = self.name_id
|
||||
if not isinstance(task_template, list):
|
||||
task_template = [task_template]
|
||||
self.task_template = task_template
|
||||
self.signal_exp_name = signal_exp_name
|
||||
self.rg = rolling_gen
|
||||
self.tool = OnlineToolR(self.exp_name)
|
||||
self.ta = TimeAdjuster()
|
||||
with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True):
|
||||
self.signal_rec = R.get_recorder() # the recorder to record signals
|
||||
self.signal_rec.save_objects(**{"signals": None})
|
||||
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must can distinguish results in different models.
|
||||
Assumption: the models can be distinguished based on model name and rolling test segments.
|
||||
If you do not want this assumption, please implement your own method or use another rec_key_func.
|
||||
|
||||
Args:
|
||||
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
|
||||
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
|
||||
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
"""
|
||||
|
||||
def rec_key(recorder):
|
||||
task_config = recorder.load_object("task")
|
||||
model_key = task_config["model"]["class"]
|
||||
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
|
||||
return model_key, rolling_key
|
||||
|
||||
if rec_key_func is None:
|
||||
rec_key_func = rec_key
|
||||
|
||||
artifacts_collector = RecorderCollector(
|
||||
experiment=self.exp_name,
|
||||
process_list=process_list,
|
||||
rec_key_func=rec_key_func,
|
||||
rec_filter_func=rec_filter_func,
|
||||
artifacts_key=artifacts_key,
|
||||
)
|
||||
|
||||
return artifacts_collector
|
||||
|
||||
def first_train(self) -> List[Recorder]:
|
||||
"""
|
||||
Use rolling_gen to generate different tasks based on task_template and trained them.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorder.
|
||||
"""
|
||||
tasks = task_generator(
|
||||
tasks=self.task_template,
|
||||
generators=self.rg, # generate different date segment
|
||||
)
|
||||
return self.prepare_online_models(tasks)
|
||||
|
||||
def prepare_tasks(self, cur_time) -> List[dict]:
|
||||
"""
|
||||
Prepare new tasks based on cur_time (None for latest).
|
||||
|
||||
You can find last online models by OnlineToolR.online_models.
|
||||
|
||||
Returns:
|
||||
List[dict]: a list of new tasks.
|
||||
"""
|
||||
latest_records, max_test = self._list_latest(self.tool.online_models())
|
||||
if max_test is None:
|
||||
self.logger.warn(f"No latest online recorders, no new tasks.")
|
||||
return []
|
||||
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
|
||||
)
|
||||
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
|
||||
old_tasks = []
|
||||
tasks_tmp = []
|
||||
for rec in latest_records:
|
||||
task = rec.load_object("task")
|
||||
old_tasks.append(deepcopy(task))
|
||||
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
|
||||
# modify the test segment to generate new tasks
|
||||
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
|
||||
tasks_tmp.append(task)
|
||||
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
|
||||
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
return new_tasks
|
||||
return []
|
||||
|
||||
def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame:
|
||||
"""
|
||||
Average the predictions of online models and offer a trading signals every routine.
|
||||
The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP`
|
||||
Even if the latest signal already exists, the latest calculation result will be overwritten.
|
||||
|
||||
.. note::
|
||||
|
||||
Given a prediction of a certain time, all signals before this time will be prepared well.
|
||||
|
||||
Args:
|
||||
over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False.
|
||||
Returns:
|
||||
pd.DataFrame: the signals.
|
||||
"""
|
||||
if not delay:
|
||||
self.tool.update_online_pred()
|
||||
|
||||
# Get a collector to average online models predictions
|
||||
online_collector = self.get_collector(
|
||||
process_list=[AverageEnsemble()],
|
||||
rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False,
|
||||
artifacts_key="pred",
|
||||
)
|
||||
online_results = online_collector()
|
||||
signals = online_results["pred"]
|
||||
|
||||
old_signals = self.get_signals()
|
||||
if old_signals is not None and not over_write:
|
||||
old_max = old_signals.index.get_level_values("datetime").max()
|
||||
new_signals = signals.loc[old_max:]
|
||||
signals = pd.concat([old_signals, new_signals], axis=0)
|
||||
else:
|
||||
new_signals = signals
|
||||
if self.need_log:
|
||||
self.logger.info(
|
||||
f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}."
|
||||
)
|
||||
self.signal_rec.save_objects(**{"signals": signals})
|
||||
return signals
|
||||
|
||||
def get_signals(self) -> object:
|
||||
"""
|
||||
Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP)
|
||||
|
||||
Returns:
|
||||
object: signals
|
||||
"""
|
||||
signals = self.signal_rec.load_object("signals")
|
||||
return signals
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
|
||||
"""
|
||||
List latest recorder form rec_list
|
||||
|
||||
Args:
|
||||
rec_list (List[Recorder]): a list of Recorder
|
||||
|
||||
Returns:
|
||||
List[Recorder], pd.Timestamp: the latest recorders and its test end time
|
||||
"""
|
||||
if len(rec_list) == 0:
|
||||
return rec_list, None
|
||||
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
|
||||
latest_rec = []
|
||||
for rec in rec_list:
|
||||
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
|
||||
latest_rec.append(rec)
|
||||
return latest_rec, max_test
|
||||
|
||||
def reset(self):
|
||||
"""
|
||||
NOTE: This method will delete all recorder in Experiment and reset the Trainer!
|
||||
"""
|
||||
self.trainer.reset()
|
||||
# delete models
|
||||
exp = R.get_exp(experiment_name=self.exp_name)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
# delete signals
|
||||
for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False):
|
||||
exp.delete_recorder(rid)
|
||||
@@ -1,18 +1,20 @@
|
||||
from typing import Union, List
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.workflow import R
|
||||
from qlib.data import D
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Updater is a module to update artifacts such as predictions, when the stock data is updating.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.workflow import R
|
||||
from qlib.model import Model
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import DatasetH
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.model import Model
|
||||
from qlib.utils import get_date_by_shift
|
||||
from qlib.workflow.recorder import Recorder
|
||||
|
||||
|
||||
class RMDLoader:
|
||||
@@ -25,19 +27,22 @@ class RMDLoader:
|
||||
|
||||
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
|
||||
"""
|
||||
load, config and setup dataset.
|
||||
Load, config and setup dataset.
|
||||
|
||||
This dataset is for inference
|
||||
This dataset is for inference.
|
||||
|
||||
Args:
|
||||
start_time :
|
||||
the start_time of underlying data
|
||||
end_time :
|
||||
the end_time of underlying data
|
||||
segments : dict
|
||||
the segments config for dataset
|
||||
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start_time of underlying data
|
||||
end_time :
|
||||
the end_time of underlying data
|
||||
segments : dict
|
||||
the segments config for dataset
|
||||
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
|
||||
"""
|
||||
if segments is None:
|
||||
segments = {"test": (start_time, end_time)}
|
||||
@@ -52,7 +57,7 @@ class RMDLoader:
|
||||
|
||||
class RecordUpdater(metaclass=ABCMeta):
|
||||
"""
|
||||
Updata a specific recorders
|
||||
Update a specific recorders
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, need_log=True, *args, **kwargs):
|
||||
@@ -75,17 +80,22 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
record : Recorder
|
||||
to_date :
|
||||
update to prediction to the `to_date`
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical depends.
|
||||
Leave the problem to user to set the length of historical dependancy
|
||||
NOTE: the start_time is not included in the hist_ref
|
||||
# TODO: automate this step in the future.
|
||||
Init PredUpdater.
|
||||
|
||||
Args:
|
||||
record : Recorder
|
||||
to_date :
|
||||
update to prediction to the `to_date`
|
||||
hist_ref : int
|
||||
Sometimes, the dataset will have historical depends.
|
||||
Leave the problem to user to set the length of historical dependency
|
||||
|
||||
.. note::
|
||||
|
||||
the start_time is not included in the hist_ref
|
||||
|
||||
"""
|
||||
# TODO: automate this hist_ref in the future.
|
||||
super().__init__(record=record, need_log=need_log)
|
||||
|
||||
self.to_date = to_date
|
||||
@@ -101,9 +111,12 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
def prepare_data(self) -> DatasetH:
|
||||
"""
|
||||
# Load dataset
|
||||
Load dataset
|
||||
|
||||
Seperating this function will make it easier to reuse the dataset
|
||||
|
||||
Returns:
|
||||
DatasetH: the instance of DatasetH
|
||||
"""
|
||||
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
@@ -113,9 +126,12 @@ class PredUpdater(RecordUpdater):
|
||||
|
||||
def update(self, dataset: DatasetH = None):
|
||||
"""
|
||||
update the precition in a recorder
|
||||
Update the precition in a recorder
|
||||
|
||||
Args:
|
||||
DatasetH: the instance of DatasetH. None for reprepare.
|
||||
"""
|
||||
# FIXME: the problme below is not solved
|
||||
# FIXME: the problem below is not solved
|
||||
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
|
||||
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
170
qlib/workflow/online/utils.py
Normal file
170
qlib/workflow/online/utils.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineTool is a module to set and unset a series of `online` models.
|
||||
The `online` models are some decisive models in some time point, which can be changed with the change of time.
|
||||
This allows us to use efficient submodels as the market style changing.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class OnlineTool:
|
||||
"""
|
||||
OnlineTool.
|
||||
"""
|
||||
|
||||
ONLINE_KEY = "online_status" # the online status key in recorder
|
||||
ONLINE_TAG = "online" # the 'online' model
|
||||
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
|
||||
|
||||
def __init__(self, need_log=True):
|
||||
"""
|
||||
Init OnlineTool.
|
||||
|
||||
Args:
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
self.need_log = need_log
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[list, object]):
|
||||
"""
|
||||
Set `tag` to the model to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[list,object]): the model's recorder
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
|
||||
|
||||
def get_online_tag(self, recorder: object) -> str:
|
||||
"""
|
||||
Given a model recorder and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Object): the model's recorder
|
||||
|
||||
Returns:
|
||||
str: the online tag
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
|
||||
|
||||
def reset_online_tag(self, recorder: Union[list, object]):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[list,object]):
|
||||
the recorder you want to reset to 'online'.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
|
||||
|
||||
def online_models(self) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `online_models` method.")
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of `online` models to a date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest.
|
||||
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
|
||||
|
||||
|
||||
class OnlineToolR(OnlineTool):
|
||||
"""
|
||||
The implementation of OnlineTool based on (R)ecorder.
|
||||
"""
|
||||
|
||||
def __init__(self, experiment_name: str, need_log=True):
|
||||
"""
|
||||
Init OnlineToolR.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the experiment name.
|
||||
need_log (bool, optional): print log or not. Defaults to True.
|
||||
"""
|
||||
super().__init__(need_log=need_log)
|
||||
self.exp_name = experiment_name
|
||||
|
||||
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Set `tag` to the model's recorder to sign whether online.
|
||||
|
||||
Args:
|
||||
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
|
||||
recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
for rec in recorder:
|
||||
rec.set_tags(**{self.ONLINE_KEY: tag})
|
||||
if self.need_log:
|
||||
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
|
||||
|
||||
def get_online_tag(self, recorder: Recorder) -> str:
|
||||
"""
|
||||
Given a model recorder and return its online tag.
|
||||
|
||||
Args:
|
||||
recorder (Recorder): an instance of recorder
|
||||
|
||||
Returns:
|
||||
str: the online tag
|
||||
"""
|
||||
tags = recorder.list_tags()
|
||||
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
|
||||
|
||||
def reset_online_tag(self, recorder: Union[Recorder, List]):
|
||||
"""
|
||||
Offline all models and set the recorders to 'online'.
|
||||
|
||||
Args:
|
||||
recorder (Union[Recorder, List]):
|
||||
the recorder you want to reset to 'online'.
|
||||
|
||||
"""
|
||||
if isinstance(recorder, Recorder):
|
||||
recorder = [recorder]
|
||||
recs = list_recorders(self.exp_name)
|
||||
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
|
||||
self.set_online_tag(self.ONLINE_TAG, recorder)
|
||||
|
||||
def online_models(self) -> list:
|
||||
"""
|
||||
Get current `online` models
|
||||
|
||||
Returns:
|
||||
list: a list of `online` models.
|
||||
"""
|
||||
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
|
||||
|
||||
def update_online_pred(self, to_date=None):
|
||||
"""
|
||||
Update the predictions of online models to a date.
|
||||
|
||||
Args:
|
||||
to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest time in Calendar.
|
||||
"""
|
||||
online_models = self.online_models()
|
||||
for rec in online_models:
|
||||
PredUpdater(rec, to_date=to_date, need_log=self.need_log).update()
|
||||
|
||||
if self.need_log:
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import re, logging
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
@@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
from ..contrib.strategy.strategy import BaseStrategy
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
@@ -166,6 +166,60 @@ class SignalRecord(RecordTemp):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "hg_sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder)
|
||||
|
||||
def generate(self):
|
||||
pred = self.load("pred.pkl")
|
||||
raw_label = self.load("label.pkl")
|
||||
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
objects = {"ic.pkl": ic, "ric.pkl": ric}
|
||||
objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre})
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0])
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Average Return": long_short_r.mean(),
|
||||
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
|
||||
}
|
||||
)
|
||||
objects.update(
|
||||
{
|
||||
"long_short_r.pkl": long_short_r,
|
||||
"long_avg_r.pkl": long_avg_r,
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
self.get_path("long_short_r.pkl"),
|
||||
self.get_path("long_avg_r.pkl"),
|
||||
]
|
||||
return paths
|
||||
|
||||
|
||||
class SigAnaRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Recorder:
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Callable, Union
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from qlib.model.ens.ensemble import SingleKeyEnsemble
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
from qlib.utils.serial import Serializable
|
||||
import dill as pickle
|
||||
|
||||
|
||||
@@ -18,7 +22,7 @@ class Collector:
|
||||
process_list = [process_list]
|
||||
self.process_list = process_list
|
||||
|
||||
def collect(self):
|
||||
def collect(self) -> dict:
|
||||
"""Collect the results and return a dict like {key: things}
|
||||
|
||||
Returns:
|
||||
@@ -35,7 +39,7 @@ class Collector:
|
||||
raise NotImplementedError(f"Please implement the `collect` method.")
|
||||
|
||||
@staticmethod
|
||||
def process_collect(collected_dict, process_list=[], *args, **kwargs):
|
||||
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
|
||||
"""do a series of processing to the dict returned by collect and return a dict like {key: things}
|
||||
For example: you can group and ensemble.
|
||||
|
||||
@@ -60,7 +64,7 @@ class Collector:
|
||||
result[artifact] = value
|
||||
return result
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
def __call__(self, *args, **kwargs) -> dict:
|
||||
"""
|
||||
do the workflow including collect and process_collect
|
||||
|
||||
@@ -78,7 +82,7 @@ class Collector:
|
||||
filepath (str): the path of file
|
||||
|
||||
Returns:
|
||||
bool: if successed
|
||||
bool: if succeeded
|
||||
"""
|
||||
try:
|
||||
with open(filepath, "wb") as f:
|
||||
@@ -109,6 +113,29 @@ class Collector:
|
||||
raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!")
|
||||
|
||||
|
||||
class HyperCollector(Collector):
|
||||
"""
|
||||
A collector to collect the results of other Collectors
|
||||
"""
|
||||
|
||||
def __init__(self, collector_dict, process_list=[]):
|
||||
"""
|
||||
Args:
|
||||
collector_dict (dict): the dict like {collector_key, Collector}
|
||||
process_list (list or Callable): the list of processors or the instance of processor to process dict.
|
||||
NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict.
|
||||
This can make result more readable. If you want to maintain as it should be, just give a empty process list.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
self.collector_dict = collector_dict
|
||||
|
||||
def collect(self) -> dict:
|
||||
collect_dict = {}
|
||||
for key, collector in self.collector_dict.items():
|
||||
collect_dict[key] = collector()
|
||||
return collect_dict
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
|
||||
ART_KEY_RAW = "__raw"
|
||||
|
||||
@@ -131,10 +158,10 @@ class RecorderCollector(Collector):
|
||||
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
|
||||
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
|
||||
"""
|
||||
super().__init__(process_list=process_list)
|
||||
if isinstance(experiment, str):
|
||||
experiment = R.get_exp(experiment_name=experiment)
|
||||
self.experiment = experiment
|
||||
self.process_list = process_list
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
@@ -144,7 +171,7 @@ class RecorderCollector(Collector):
|
||||
self.artifacts_key = artifacts_key
|
||||
self._rec_filter_func = rec_filter_func
|
||||
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None):
|
||||
def collect(self, artifacts_key=None, rec_filter_func=None) -> dict:
|
||||
"""Collect different artifacts based on recorder after filtering.
|
||||
|
||||
Args:
|
||||
@@ -180,3 +207,12 @@ class RecorderCollector(Collector):
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
def get_exp_name(self) -> str:
|
||||
"""
|
||||
Get experiment name
|
||||
|
||||
Returns:
|
||||
str: experiment name
|
||||
"""
|
||||
return self.experiment.name
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
this is a task generator
|
||||
Task generator can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
@@ -113,7 +113,7 @@ class RollingGen(TaskGen):
|
||||
self.test_key = "test"
|
||||
self.train_key = "train"
|
||||
|
||||
def generate(self, task: dict):
|
||||
def generate(self, task: dict) -> typing.List[dict]:
|
||||
"""
|
||||
Converting the task into a rolling task.
|
||||
|
||||
@@ -158,6 +158,10 @@ class RollingGen(TaskGen):
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
Returns
|
||||
----------
|
||||
typing.List[dict]: a list of tasks
|
||||
"""
|
||||
res = []
|
||||
|
||||
@@ -196,16 +200,18 @@ class RollingGen(TaskGen):
|
||||
|
||||
# update segments of this task
|
||||
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if (
|
||||
self.modify_end_time
|
||||
and self.ta.cal_interval(
|
||||
|
||||
try:
|
||||
interval = self.ta.cal_interval(
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
|
||||
t["dataset"]["kwargs"]["segments"][self.test_key][1],
|
||||
)
|
||||
< 0
|
||||
):
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
|
||||
# if end_time < the end of test_segments, then change end_time to allow load more data
|
||||
if self.modify_end_time and interval < 0:
|
||||
t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1])
|
||||
except KeyError:
|
||||
# Maybe the user dataset has no handler or end_time
|
||||
pass
|
||||
prev_seg = segments
|
||||
res.append(t)
|
||||
return res
|
||||
|
||||
@@ -1,31 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
A task consists of 3 parts
|
||||
TaskManager can fetch unused tasks automatically and manager the lifecycle of a set of tasks with error handling.
|
||||
These features can run tasks concurrently and ensure every task will be used only once.
|
||||
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
A task in TaskManager consists of 3 parts
|
||||
- tasks description: the desc will define the task
|
||||
- tasks status: the status of the task
|
||||
- tasks result information : A user can get the task with the task description and task result.
|
||||
|
||||
"""
|
||||
from bson.binary import Binary
|
||||
import pickle
|
||||
from pymongo.errors import InvalidDocument
|
||||
from bson.objectid import ObjectId
|
||||
from contextlib import contextmanager
|
||||
import qlib
|
||||
from tqdm.cli import tqdm
|
||||
import time
|
||||
import concurrent
|
||||
import pymongo
|
||||
from qlib.config import C
|
||||
from .utils import get_mongodb
|
||||
from qlib import get_module_logger, auto_init
|
||||
import pickle
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List
|
||||
|
||||
import fire
|
||||
import pymongo
|
||||
from bson.binary import Binary
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.errors import InvalidDocument
|
||||
from qlib import auto_init, get_module_logger
|
||||
from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""TaskManager
|
||||
here is what will a task looks like when it created by TaskManager
|
||||
"""
|
||||
TaskManager
|
||||
|
||||
Here is what will a task looks like when it created by TaskManager
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -42,6 +50,16 @@ class TaskManager:
|
||||
.. note::
|
||||
|
||||
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
|
||||
Here are four status which are:
|
||||
|
||||
STATUS_WAITING: waiting for train
|
||||
|
||||
STATUS_RUNNING: training
|
||||
|
||||
STATUS_PART_DONE: finished some step and waiting for next step
|
||||
|
||||
STATUS_DONE: all work done
|
||||
"""
|
||||
|
||||
STATUS_WAITING = "waiting"
|
||||
@@ -53,7 +71,7 @@ class TaskManager:
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
"""
|
||||
init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -65,7 +83,7 @@ class TaskManager:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self):
|
||||
def list(self) -> list:
|
||||
"""
|
||||
list the all collection(task_pool) of the db
|
||||
|
||||
@@ -92,7 +110,9 @@ class TaskManager:
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
# assume that the data out of interface was decoded and the data in interface was encoded
|
||||
"""
|
||||
Use a new task to replace a old one
|
||||
"""
|
||||
new_task = self._encode_task(new_task)
|
||||
query = {"_id": ObjectId(task["_id"])}
|
||||
try:
|
||||
@@ -121,7 +141,7 @@ class TaskManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
pymongo.results.InsertOneResult
|
||||
"""
|
||||
task = self._encode_task(
|
||||
{
|
||||
@@ -133,9 +153,9 @@ class TaskManager:
|
||||
insert_result = self.insert_task(task)
|
||||
return insert_result
|
||||
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False):
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
|
||||
"""
|
||||
if the tasks in task_def_l is new, then insert new tasks into the task_pool
|
||||
If the tasks in task_def_l is new, then insert new tasks into the task_pool
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -145,6 +165,7 @@ class TaskManager:
|
||||
if insert those new tasks to task pool
|
||||
print_nt: bool
|
||||
if print new task
|
||||
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
@@ -165,7 +186,7 @@ class TaskManager:
|
||||
print(t)
|
||||
|
||||
if dry_run:
|
||||
return
|
||||
return []
|
||||
|
||||
_id_list = []
|
||||
for t in new_tasks:
|
||||
@@ -174,7 +195,17 @@ class TaskManager:
|
||||
|
||||
return _id_list
|
||||
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING):
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
|
||||
"""
|
||||
Use query to fetch tasks
|
||||
|
||||
Args:
|
||||
query (dict, optional): query dict. Defaults to {}.
|
||||
status (str, optional): [description]. Defaults to STATUS_WAITING.
|
||||
|
||||
Returns:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
@@ -191,7 +222,7 @@ class TaskManager:
|
||||
@contextmanager
|
||||
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
|
||||
"""
|
||||
fetch task from task_pool using query with contextmanager
|
||||
Fetch task from task_pool using query with contextmanager
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -200,7 +231,7 @@ class TaskManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
@@ -231,7 +262,7 @@ class TaskManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
@@ -240,16 +271,40 @@ class TaskManager:
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
Args:
|
||||
_id (str): _id of a document
|
||||
|
||||
Returns:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
t = self.task_pool.find_one({"_id": ObjectId(_id)})
|
||||
return self._decode_task(t)
|
||||
|
||||
def commit_task_res(self, task, res, status=None):
|
||||
def commit_task_res(self, task, res, status=STATUS_DONE):
|
||||
"""
|
||||
Commit the result to task['res'].
|
||||
|
||||
Args:
|
||||
task ([type]): [description]
|
||||
res (object): the result you want to save
|
||||
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
|
||||
"""
|
||||
# A workaround to use the class attribute.
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_DONE
|
||||
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
|
||||
|
||||
def return_task(self, task, status=None):
|
||||
def return_task(self, task, status=STATUS_WAITING):
|
||||
"""
|
||||
Return a task to status. Alway using in error handling.
|
||||
|
||||
Args:
|
||||
task ([type]): [description]
|
||||
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
|
||||
"""
|
||||
if status is None:
|
||||
status = TaskManager.STATUS_WAITING
|
||||
update_dict = {"$set": {"status": status}}
|
||||
@@ -257,7 +312,7 @@ class TaskManager:
|
||||
|
||||
def remove(self, query={}):
|
||||
"""
|
||||
remove the task using query
|
||||
Remove the task using query
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -295,7 +350,7 @@ class TaskManager:
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
"""
|
||||
set priority for task
|
||||
Set priority for task
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -331,29 +386,41 @@ class TaskManager:
|
||||
|
||||
|
||||
def run_task(
|
||||
task_func,
|
||||
task_pool,
|
||||
force_release=False,
|
||||
before_status=TaskManager.STATUS_WAITING,
|
||||
after_status=TaskManager.STATUS_DONE,
|
||||
*args,
|
||||
task_func: Callable,
|
||||
task_pool: str,
|
||||
force_release: bool = False,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
|
||||
|
||||
After running this method, here are 4 situations (before_status -> after_status):
|
||||
|
||||
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
|
||||
|
||||
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
|
||||
|
||||
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
|
||||
|
||||
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_func : def (task_def, *args, **kwargs) -> <res which will be committed>
|
||||
the function to run the task
|
||||
task_func : Callable
|
||||
def (task_def, **kwargs) -> <res which will be committed>
|
||||
the function to run the task
|
||||
task_pool : str
|
||||
the name of the task pool (Collection in MongoDB)
|
||||
force_release :
|
||||
force_release : bool
|
||||
will the program force to release the resource
|
||||
args :
|
||||
args
|
||||
kwargs :
|
||||
kwargs
|
||||
before_status : str:
|
||||
the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status : str:
|
||||
the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
kwargs
|
||||
the params for `task_func`
|
||||
"""
|
||||
tm = TaskManager(task_pool)
|
||||
|
||||
@@ -364,19 +431,19 @@ def run_task(
|
||||
if task is None:
|
||||
break
|
||||
get_module_logger("run_task").info(task["def"])
|
||||
# when fetching `WAITING` task, use task_def to train
|
||||
# when fetching `WAITING` task, use task["def"] to train
|
||||
if before_status == TaskManager.STATUS_WAITING:
|
||||
param = task["def"]
|
||||
# when fetching `PART_DONE` task, use task_res to train for the result has been saved
|
||||
# when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
|
||||
elif before_status == TaskManager.STATUS_PART_DONE:
|
||||
param = task["res"]
|
||||
else:
|
||||
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
|
||||
if force_release:
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
res = executor.submit(task_func, param, *args, **kwargs).result()
|
||||
res = executor.submit(task_func, param, **kwargs).result()
|
||||
else:
|
||||
res = task_func(param, *args, **kwargs)
|
||||
res = task_func(param, **kwargs)
|
||||
tm.commit_task_res(task, res, status=after_status)
|
||||
ever_run = True
|
||||
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Some tools for task management.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
@@ -7,13 +12,14 @@ from qlib.workflow import R
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database
|
||||
from typing import Union
|
||||
|
||||
|
||||
def get_mongodb():
|
||||
"""
|
||||
def get_mongodb() -> Database:
|
||||
|
||||
get database in MongoDB, which means you need to declare the address and the name of database.
|
||||
"""
|
||||
Get database in MongoDB, which means you need to declare the address and the name of database.
|
||||
for example:
|
||||
|
||||
Using qlib.init():
|
||||
@@ -31,6 +37,8 @@ def get_mongodb():
|
||||
"task_db_name" : "rolling_db"
|
||||
}
|
||||
|
||||
Returns:
|
||||
Database: the Database instance
|
||||
"""
|
||||
try:
|
||||
cfg = C["mongo"]
|
||||
@@ -43,7 +51,8 @@ def get_mongodb():
|
||||
|
||||
|
||||
def list_recorders(experiment, rec_filter_func=None):
|
||||
"""list all recorders which can pass the filter in a experiment.
|
||||
"""
|
||||
List all recorders which can pass the filter in a experiment.
|
||||
|
||||
Args:
|
||||
experiment (str or Experiment): the name of a Experiment or a instance
|
||||
@@ -65,7 +74,7 @@ def list_recorders(experiment, rec_filter_func=None):
|
||||
|
||||
class TimeAdjuster:
|
||||
"""
|
||||
find appropriate date and adjust date.
|
||||
Find appropriate date and adjust date.
|
||||
"""
|
||||
|
||||
def __init__(self, future=True, end_time=None):
|
||||
@@ -88,15 +97,15 @@ class TimeAdjuster:
|
||||
return None
|
||||
return self.cals[idx]
|
||||
|
||||
def max(self):
|
||||
def max(self) -> pd.Timestamp:
|
||||
"""
|
||||
Return the max calendar datetime
|
||||
"""
|
||||
return max(self.cals)
|
||||
|
||||
def align_idx(self, time_point, tp_type="start"):
|
||||
def align_idx(self, time_point, tp_type="start") -> int:
|
||||
"""
|
||||
align the index of time_point in the calendar
|
||||
Align the index of time_point in the calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -116,9 +125,9 @@ class TimeAdjuster:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return idx
|
||||
|
||||
def cal_interval(self, time_point_A, time_point_B):
|
||||
def cal_interval(self, time_point_A, time_point_B) -> int:
|
||||
"""
|
||||
calculate the trading day interval
|
||||
Calculate the trading day interval (time_point_A - time_point_B)
|
||||
|
||||
Args:
|
||||
time_point_A : time_point_A
|
||||
@@ -129,20 +138,22 @@ class TimeAdjuster:
|
||||
"""
|
||||
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
|
||||
|
||||
def align_time(self, time_point, tp_type="start"):
|
||||
def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
|
||||
"""
|
||||
Align time_point to trade date of calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
time_point
|
||||
Time point
|
||||
tp_type : str
|
||||
time point type (`"start"`, `"end"`)
|
||||
Args:
|
||||
time_point
|
||||
Time point
|
||||
tp_type : str
|
||||
time point type (`"start"`, `"end"`)
|
||||
|
||||
Returns:
|
||||
pd.Timestamp
|
||||
"""
|
||||
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
|
||||
|
||||
def align_seg(self, segment: Union[dict, tuple]):
|
||||
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
|
||||
"""
|
||||
align the given date to trade date
|
||||
|
||||
@@ -162,7 +173,7 @@ class TimeAdjuster:
|
||||
|
||||
Returns
|
||||
-------
|
||||
the start and end trade date (pd.Timestamp) between the given start and end date.
|
||||
Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
|
||||
"""
|
||||
if isinstance(segment, dict):
|
||||
return {k: self.align_seg(seg) for k, seg in segment.items()}
|
||||
@@ -171,7 +182,7 @@ class TimeAdjuster:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def truncate(self, segment: tuple, test_start, days: int):
|
||||
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
|
||||
"""
|
||||
truncate the segment based on the test_start date
|
||||
|
||||
@@ -183,6 +194,10 @@ class TimeAdjuster:
|
||||
days : int
|
||||
The trading days to be truncated
|
||||
the data in this segment may need 'days' data
|
||||
|
||||
Returns
|
||||
---------
|
||||
tuple: new segment
|
||||
"""
|
||||
test_idx = self.align_idx(test_start)
|
||||
if isinstance(segment, tuple):
|
||||
@@ -198,7 +213,7 @@ class TimeAdjuster:
|
||||
SHIFT_SD = "sliding"
|
||||
SHIFT_EX = "expanding"
|
||||
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD):
|
||||
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
|
||||
"""
|
||||
shift the datatime of segment
|
||||
|
||||
@@ -211,6 +226,10 @@ class TimeAdjuster:
|
||||
rtype : str
|
||||
rolling type ("sliding" or "expanding")
|
||||
|
||||
Returns
|
||||
--------
|
||||
tuple: new segment
|
||||
|
||||
Raises
|
||||
------
|
||||
KeyError:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal, atexit
|
||||
import sys, traceback, signal, atexit, logging
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
# function to handle the experiment when unusual program ending occurs
|
||||
|
||||
24
scripts/data_collector/contrib/README.md
Normal file
24
scripts/data_collector/contrib/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Get future trading days
|
||||
|
||||
> `D.calendar(future=True)` will be used
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- qlib_dir: qlib data directory
|
||||
- freq: value from [`day`, `1min`], default `day`
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
# get data from baostock
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
|
||||
calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt")
|
||||
if not calendar_path.exists():
|
||||
return pd.DataFrame()
|
||||
return pd.read_csv(calendar_path, header=None)
|
||||
|
||||
|
||||
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
|
||||
calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt"))
|
||||
|
||||
np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
|
||||
logger.info(f"write future calendars success: {calendar_path}")
|
||||
|
||||
|
||||
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
|
||||
print(freq)
|
||||
if freq == "day":
|
||||
return date_list
|
||||
elif freq == "1min":
|
||||
date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
|
||||
return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list))
|
||||
else:
|
||||
raise ValueError(f"Unsupported freq: {freq}")
|
||||
|
||||
|
||||
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
|
||||
"""get future calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
qlib_dir: str or Path
|
||||
qlib data directory
|
||||
freq: str
|
||||
value from ["day", "1min"], by default day
|
||||
"""
|
||||
qlib_dir = Path(qlib_dir).expanduser().resolve()
|
||||
if not qlib_dir.exists():
|
||||
raise FileNotFoundError(str(qlib_dir))
|
||||
|
||||
lg = bs.login()
|
||||
if lg.error_code != "0":
|
||||
logger.error(f"login error: {lg.error_msg}")
|
||||
return
|
||||
# read daily calendar
|
||||
daily_calendar = read_calendar_from_qlib(qlib_dir)
|
||||
end_year = pd.Timestamp.now().year
|
||||
if daily_calendar.empty:
|
||||
start_year = pd.Timestamp.now().year
|
||||
else:
|
||||
start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
|
||||
rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
|
||||
data_list = []
|
||||
while (rs.error_code == "0") & rs.next():
|
||||
_row_data = rs.get_row_data()
|
||||
if int(_row_data[1]) == 1:
|
||||
data_list.append(_row_data[0])
|
||||
data_list = sorted(data_list)
|
||||
date_list = generate_qlib_calendar(data_list, freq=freq)
|
||||
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
|
||||
bs.logout()
|
||||
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(future_calendar_collector)
|
||||
5
scripts/data_collector/contrib/requirements.txt
Normal file
5
scripts/data_collector/contrib/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
baostock
|
||||
fire
|
||||
numpy
|
||||
pandas
|
||||
loguru
|
||||
@@ -10,7 +10,9 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
@@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
|
||||
return res
|
||||
|
||||
|
||||
def generate_minutes_calendar_from_daily(
|
||||
calendars: Iterable,
|
||||
freq: str = "1min",
|
||||
am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
|
||||
pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
|
||||
) -> pd.Index:
|
||||
"""generate minutes calendar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendars: Iterable
|
||||
daily calendar
|
||||
freq: str
|
||||
by default 1min
|
||||
am_range: Tuple[str, str]
|
||||
AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
|
||||
pm_range: Tuple[str, str]
|
||||
PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")
|
||||
|
||||
"""
|
||||
daily_format: str = "%Y-%m-%d"
|
||||
res = []
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}",
|
||||
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}",
|
||||
freq=freq,
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
|
||||
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
res = []
|
||||
daily_format = self.DAILY_FORMAT
|
||||
am_range = self.AM_RANGE
|
||||
pm_range = self.PM_RANGE
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{_day.strftime(daily_format)} {_range[0]}",
|
||||
f"{_day.strftime(daily_format)} {_range[1]}",
|
||||
freq="1min",
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
|
||||
Reference in New Issue
Block a user