1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

Merge remote-tracking branch 'microsoft/main' into data_storage

This commit is contained in:
zhupr
2021-05-21 09:49:29 +08:00
104 changed files with 5656 additions and 485 deletions

62
.github/stale.yml vendored
View File

@@ -1,62 +0,0 @@
# Configuration for probot-stale - https://github.com/probot/stale
# Number of days of inactivity before an Issue or Pull Request becomes stale
daysUntilStale: 60
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
daysUntilClose: 7
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
onlyLabels: []
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
exemptLabels:
- bug
- pinned
- security
- "[Status] Maybe Later"
# Set to true to ignore issues in a project (defaults to false)
exemptProjects: false
# Set to true to ignore issues in a milestone (defaults to false)
exemptMilestones: false
# Set to true to ignore issues with an assignee (defaults to false)
exemptAssignees: false
# Label to use when marking as stale
staleLabel: wontfix
# Comment to post when marking as stale. Set to `false` to disable
markComment: >
This issue has been automatically marked as stale because it has not had
recent activity. It will be closed if no further activity occurs. Thank you
for your contributions.
# Comment to post when removing the stale label.
# unmarkComment: >
# Your comment here.
# Comment to post when closing a stale Issue or Pull Request.
# closeComment: >
# Your comment here.
# Limit the number of actions per hour, from 1-30. Default is 30
limitPerRun: 30
# Limit to only `issues` or `pulls`
# only: issues
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
# pulls:
# daysUntilStale: 30
# markComment: >
# This pull request has been automatically marked as stale because it has not had
# recent activity. It will be closed if no further activity occurs. Thank you
# for your contributions.
# issues:
# exemptLabels:
# - confirmed

24
.github/workflows/stale.yml vendored Normal file
View File

@@ -0,0 +1,24 @@
name: Mark stale issues and pull requests
on:
schedule:
- cron: "0 0/3 * * *"
jobs:
stale:
runs-on: ubuntu-latest
steps:
- uses: actions/stale@v3
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
stale-issue-label: 'stale'
stale-pr-label: 'stale'
days-before-stale: 90
days-before-close: 5
operations-per-run: 100
exempt-issue-labels: 'bug,enhancement'
remove-stale-when-updated: true

View File

@@ -45,16 +45,16 @@ New features under development(order by estimated release time).
Your feedbacks about the features are very important.
| Feature | Status |
| -- | ------ |
| Online serving and automatic model rolling | Under review: https://github.com/microsoft/qlib/pull/290 |
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
| High-frequency trading | Initial opensource version under development |
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
| Meta-Learning-based data selection | Initial opensource version under development |
Recent released features
| Feature | Status |
| -- | ------ |
| Online serving and automatic model rolling | Released: https://github.com/microsoft/qlib/pull/290 |
| DoubleEnsemble Model | Released https://github.com/microsoft/qlib/pull/286 |
| High-frequency data processing example | Released https://github.com/microsoft/qlib/pull/257 |
| High-frequency trading example | Part of code released https://github.com/microsoft/qlib/pull/227 |
@@ -243,6 +243,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
- Rank Label
![Rank Label](docs/_static/img/rank_label.png)
-->
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
## Building Customized Quant Research Workflow by Code
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.

BIN
docs/_static/img/online_serving.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 440 KiB

View File

@@ -14,6 +14,9 @@ Serializable Class
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
Users can also override ``pickle_backend`` attribute to choose a pickle backend. The supported value is "pickle" (default and common) and "dill" (dump more things such as function, more information in `here <https://pypi.org/project/dill/>`_).
Example
==========================

View File

@@ -0,0 +1,89 @@
.. _task_management:
=================================
Task Management
=================================
.. currentmodule:: qlib
Introduction
=============
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
This whole process can be used in `Online Serving <../component/online.html>`_.
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
Task Generating
===============
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
The specific task template can be viewed in
`Task Section <../component/workflow.html#task-section>`_.
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
Here is the base class of ``TaskGen``:
.. autoclass:: qlib.workflow.task.gen.TaskGen
:members:
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
Task Storing
===============
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
.. code-block:: python
from qlib.config import C
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
"task_db_name" : "rolling_db" # database name
}
.. autoclass:: qlib.workflow.task.manage.TaskManager
:members:
More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_.
Task Training
===============
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
.. autofunction:: qlib.workflow.task.manage.run_task
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
.. autoclass:: qlib.model.trainer.Trainer
:members:
``Trainer`` will train a list of tasks and return a list of model recorders.
``Qlib`` offer two kinds of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically.
If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough.
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
Task Collecting
===============
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict).
`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule).
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
For example: {C1: object, C2: object} ---``Ensemble``---> object
So the hierarchy is ``Collector``'s second step corresponds to ``Group``. And ``Group``'s second step correspond to ``Ensemble``.
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.

View File

@@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
.. note::
PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. And then we will use the code to create data cache on our server which other users could use directly.
Data API
========================
@@ -298,9 +303,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
:members: __init__, fetch, get_cols
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
Processor
@@ -337,7 +343,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
.. code-block:: Python
import qlib
@@ -364,6 +369,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
# fetch all the features
print(h.fetch(col_set="feature"))
.. note:: In the ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1` that means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`, of which the reason is that when getting the T day close price of a china stock, the stock can be bought on T+1 day and sold on T+2 day.
API
---------
@@ -388,8 +396,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
API
---------
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
Cache

46
docs/component/online.rst Normal file
View File

@@ -0,0 +1,46 @@
.. _online:
=================================
Online Serving
=================================
.. currentmodule:: qlib
Introduction
=============
.. image:: ../_static/img/online_serving.png
:align: center
In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
``Online Serving`` is a set of modules for online models using the latest data,
which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
If you have many models or `task` needs to be managed, please consider `Task Management <../advanced/task_management.html>`_.
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
Online Manager
=============
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
=============
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
=============
.. automodule:: qlib.workflow.online.utils
:members:
Updater
=============
.. automodule:: qlib.workflow.online.update
:members:

View File

@@ -34,6 +34,7 @@ Here is a general view of the structure of the system:
- Recorder 2
- ...
- ...
This experiment management system defines a set of interface and provided a concrete implementation ``MLflowExpManager``, which is based on the machine learning platform: ``MLFlow`` (`link <https://mlflow.org/>`_).
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, pleaes refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.

View File

@@ -42,6 +42,7 @@ Document Structure
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
Qlib Recorder: Experiment Management <component/recorder.rst>
Analysis: Evaluation & Results Analysis <component/report.rst>
Online Serving: Online Management & Strategy & Tool <component/online.rst>
.. toctree::
:maxdepth: 3
@@ -50,6 +51,7 @@ Document Structure
Building Formulaic Alphas <advanced/alpha.rst>
Online & Offline mode <advanced/server.rst>
Serialization <advanced/serial.rst>
Task Management <advanced/task_management.rst>
.. toctree::
:maxdepth: 3

View File

@@ -154,6 +154,70 @@ Record Template
.. automodule:: qlib.workflow.record_temp
:members:
Task Management
====================
TaskGen
--------------------
.. automodule:: qlib.workflow.task.gen
:members:
TaskManager
--------------------
.. automodule:: qlib.workflow.task.manage
:members:
Trainer
--------------------
.. automodule:: qlib.model.trainer
:members:
Collector
--------------------
.. automodule:: qlib.workflow.task.collect
:members:
Group
--------------------
.. automodule:: qlib.model.ens.group
:members:
Ensemble
--------------------
.. automodule:: qlib.model.ens.ensemble
:members:
Utils
--------------------
.. automodule:: qlib.workflow.task.utils
:members:
Online Serving
====================
Online Manager
--------------------
.. automodule:: qlib.workflow.online.manager
:members:
Online Strategy
--------------------
.. automodule:: qlib.workflow.online.strategy
:members:
Online Tool
--------------------
.. automodule:: qlib.workflow.online.utils
:members:
RecordUpdater
--------------------
.. automodule:: qlib.workflow.online.update
:members:
Utils
====================
@@ -162,4 +226,7 @@ Serializable
--------------------
.. automodule:: qlib.utils.serial.Serializable
:members:
:members:

View File

@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
"default_exp_name": "Experiment",
}
})
- `mongo`
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
Users need finished `installation <https://www.mongodb.com/try/download/community>`_ firstly, and run it in a fixed URL.
.. code-block:: Python
# For example, you can initialize qlib below
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
"task_url": "mongodb://localhost:27017/", # your mongo url
"task_db_name": "rolling_db", # the database name of Task Management
})

View File

@@ -0,0 +1,81 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data"
region: cn
market: &market csi300
benchmark: &benchmark SH000300
data_handler_config: &data_handler_config
start_time: 2008-01-01
end_time: 2020-08-01
instruments: *market
data_loader:
class: QlibDataLoader
kwargs:
config:
feature:
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
label:
- ["Ref($close, -2)/Ref($close, -1) - 1"]
- ["LABEL0"]
freq: day
learn_processors:
- class: DropnaLabel
- class: CSZScoreNorm
kwargs:
fields_group: label
port_analysis_config: &port_analysis_config
strategy:
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy.strategy
kwargs:
topk: 50
n_drop: 5
backtest:
verbose: False
limit_threshold: 0.095
account: 100000000
benchmark: *benchmark
deal_price: close
open_cost: 0.0005
close_cost: 0.0015
min_cost: 5
task:
model:
class: LGBModel
module_path: qlib.contrib.model.gbdt
kwargs:
loss: mse
colsample_bytree: 0.8879
learning_rate: 0.2
subsample: 0.8789
lambda_l1: 205.6999
lambda_l2: 580.9768
max_depth: 8
num_leaves: 210
num_threads: 20
dataset:
class: DatasetH
module_path: qlib.data.dataset
kwargs:
handler:
class: DataHandlerLP
module_path: qlib.data.dataset.handler
kwargs: *data_handler_config
segments:
train: [2008-01-01, 2014-12-31]
valid: [2015-01-01, 2016-12-31]
test: [2017-01-01, 2020-08-01]
record:
- class: SignalRecord
module_path: qlib.workflow.record_temp
kwargs: {}
- class: SigAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
ana_long_short: False
ann_scaler: 252
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config

View File

@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
- The selected 20 features are based on the feature importance of a lightgbm-based model.
- The base model of DoubleEnsemble is LGBM.

View File

@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
return -1, -1
def get_column_definition(self):
""""Returns formatted column definition in order expected by the TFT."""
"""Returns formatted column definition in order expected by the TFT."""
column_definition = self._column_definition

View File

@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
Run the example by running the following command:
```bash
python workflow.py dump_and_load_dataset
```
```
## Benchmarks Performance
### Signal Test
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|---|---|---|---|---|---|---|---|---|---|
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |

View File

@@ -27,12 +27,11 @@ from qlib.tests.data import GetData
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
class HighfreqWorkflow(object):
class HighfreqWorkflow:
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
MARKET = "all"
BENCHMARK = "SH000300"
start_time = "2020-09-15 00:00:00"
end_time = "2021-01-18 16:00:00"
@@ -146,35 +145,40 @@ class HighfreqWorkflow(object):
self._prepare_calender_cache()
##=============reinit dataset=============
dataset.init(
dataset.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_LS,
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.init(
dataset_backtest.config(
handler_kwargs={
"start_time": "2021-01-19 00:00:00",
"end_time": "2021-01-25 16:00:00",
},
segment_kwargs={
segments={
"test": (
"2021-01-19 00:00:00",
"2021-01-25 16:00:00",
),
},
)
dataset_backtest.setup_data(handler_kwargs={})
##=============get data=============
xtest = dataset.prepare(["test"])
backtest_test = dataset_backtest.prepare(["test"])
xtest = dataset.prepare("test")
backtest_test = dataset_backtest.prepare("test")
print(xtest, backtest_test)
return

View File

@@ -0,0 +1,65 @@
qlib_init:
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
region: cn
market: &market 'csi300'
start_time: &start_time "2020-09-15 00:00:00"
end_time: &end_time "2021-01-18 16:00:00"
train_end_time: &train_end_time "2020-11-15 16:00:00"
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
test_start_time: &test_start_time "2020-12-01 00:00:00"
data_handler_config: &data_handler_config
start_time: *start_time
end_time: *end_time
fit_start_time: *start_time
fit_end_time: *train_end_time
instruments: *market
freq: '1min'
infer_processors:
- class: 'RobustZScoreNorm'
kwargs:
fields_group: 'feature'
clip_outlier: false
- class: "Fillna"
kwargs:
fields_group: 'feature'
learn_processors:
- class: 'DropnaLabel'
- class: 'CSRankNorm'
kwargs:
fields_group: 'label'
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
task:
model:
class: "HFLGBModel"
module_path: "qlib.contrib.model.highfreq_gdbt_model"
kwargs:
objective: 'binary'
metric: ['binary_logloss','auc']
verbosity: -1
learning_rate: 0.01
max_depth: 8
num_leaves: 150
lambda_l1: 1.5
lambda_l2: 1
num_threads: 20
dataset:
class: "DatasetH"
module_path: "qlib.data.dataset"
kwargs:
handler:
class: "Alpha158"
module_path: "qlib.contrib.data.handler"
kwargs: *data_handler_config
segments:
train: [*start_time, *train_end_time]
valid: [*train_end_time, *valid_end_time]
test: [*test_start_time, *end_time]
record:
- class: "SignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}
- class: "HFSignalRecord"
module_path: "qlib.workflow.record_temp"
kwargs: {}

View File

@@ -0,0 +1,23 @@
# LightGBM hyperparameter
## Alpha158
First terminal
```
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_158.py
```
## Alpha360
First terminal
```
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
```
Second terminal
```
python hyperparameter_360.py
```

View File

@@ -0,0 +1,76 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
study = optuna.Study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,76 @@
import qlib
from qlib.config import REG_CN
from qlib.utils import exists_qlib_data, init_instance_by_config
import optuna
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
sys.path.append(str(scripts_dir))
from get_data import GetData
GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")
market = "csi300"
benchmark = "SH000300"
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
dataset_task = {
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha360",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
dataset = init_instance_by_config(dataset_task["dataset"])
def objective(trial):
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
"learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
"subsample": trial.suggest_uniform("subsample", 0, 1),
"lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
"lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
"max_depth": 10,
"num_leaves": trial.suggest_int("num_leaves", 1, 1024),
"feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
"bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
"bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
"min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
"min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
},
},
}
evals_result = dict()
model = init_instance_by_config(task["model"])
model.fit(dataset, evals_result=evals_result)
return min(evals_result["valid"])
study = optuna.Study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
study.optimize(objective, n_jobs=6)

View File

@@ -0,0 +1,5 @@
pandas==1.1.2
numpy==1.17.4
lightgbm==3.1.0
optuna==2.7.0
optuna-dashboard==0.4.1

View File

@@ -0,0 +1,159 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
"""
from pprint import pprint
import fire
import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.manage import TaskManager
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class RollingTaskExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region=REG_CN,
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
experiment_name="rolling_exp",
task_pool="rolling_task",
task_config=[task_xgboost_config, task_lgb_config],
rolling_step=550,
rolling_type=RollingGen.ROLL_SD,
):
# TaskManager config
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.experiment_name = experiment_name
self.task_pool = task_pool
self.task_config = task_config
self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)
# Reset all things to the first status, be careful to save important data
def reset(self):
print("========== reset ==========")
TaskManager(task_pool=self.task_pool).remove()
exp = R.get_exp(experiment_name=self.experiment_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
def task_generating(self):
print("========== task_generating ==========")
tasks = task_generator(
tasks=self.task_config,
generators=self.rolling_gen, # generate different date segments
)
pprint(tasks)
return tasks
def task_training(self, tasks):
print("========== task_training ==========")
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
def task_collecting(self):
print("========== task_collecting ==========")
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
def my_filter(recorder):
# only choose the results of "LGBModel"
model_key, rolling_key = rec_key(recorder)
if model_key == "LGBModel":
return True
return False
collector = RecorderCollector(
experiment=self.experiment_name,
process_list=RollingGroup(),
rec_key_func=rec_key,
rec_filter_func=my_filter,
)
print(collector())
def main(self):
self.reset()
tasks = self.task_generating()
self.task_training(tasks)
self.task_collecting()
if __name__ == "__main__":
## to see the whole process with your own parameters, use the command below
# python task_manager_rolling.py main --experiment_name="your_exp_name"
fire.Fire(RollingTaskExample)

View File

@@ -0,0 +1,146 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
from qlib.workflow import R
from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
data_handler_config = {
"start_time": "2018-01-01",
"end_time": "2018-10-31",
"fit_start_time": "2018-01-01",
"fit_end_time": "2018-03-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2018-01-01", "2018-03-31"),
"valid": ("2018-04-01", "2018-05-31"),
"test": ("2018-06-01", "2018-09-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class OnlineSimulationExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
end_time="2018-10-31",
tasks=[task_xgboost_config, task_lgb_config],
):
"""
Init OnlineManagerExample.
Args:
provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
region (str, optional): the stock region. Defaults to "cn".
exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
task_db_name (str, optional): database name. Defaults to "rolling_db".
task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
rolling_step (int, optional): the step for rolling. Defaults to 80.
start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
self.end_time = end_time
mongo_conf = {
"task_url": task_url,
"task_db_name": task_db_name,
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
begin_time=self.start_time,
)
self.tasks = tasks
# Reset all things to the first status, be careful to save important data
def reset(self):
TaskManager(self.task_pool).remove()
exp = R.get_exp(experiment_name=self.exp_name)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
# Run this to run all workflow automatically
def main(self):
print("========== reset ==========")
self.reset()
print("========== simulate ==========")
self.rolling_online_manager.simulate(end_time=self.end_time)
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below
# python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60
fire.Fire(OnlineSimulationExample)

View File

@@ -0,0 +1,181 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineManager works with rolling tasks.
There are four parts including first train, routine 1, add strategy and routine 2.
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
Finally, the OnlineManager will finish second routine and update all strategies.
"""
import os
import fire
import qlib
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
data_handler_config = {
"start_time": "2013-01-01",
"end_time": "2020-09-25",
"fit_start_time": "2013-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2013-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2015-12-31"),
"test": ("2016-01-01", "2020-07-10"),
},
},
}
record_config = [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
{
"class": "SigAnaRecord",
"module_path": "qlib.workflow.record_temp",
},
]
# use lgb model
task_lgb_config = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": dataset_config,
"record": record_config,
}
# use xgboost model
task_xgboost_config = {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
"dataset": dataset_config,
"record": record_config,
}
class RollingOnlineExample:
def __init__(
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
task_url="mongodb://10.0.0.4:27017/",
task_db_name="rolling_db",
rolling_step=550,
tasks=[task_xgboost_config],
add_tasks=[task_lgb_config],
):
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
self.tasks = tasks
self.add_tasks = add_tasks
self.rolling_step = rolling_step
strategies = []
for task in tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager = OnlineManager(strategies)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
if os.path.exists(self._ROLLING_MANAGER_PATH):
os.remove(self._ROLLING_MANAGER_PATH)
def first_run(self):
print("========== reset ==========")
self.reset()
print("========== first_run ==========")
self.rolling_online_manager.first_train()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def routine(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== routine ==========")
self.rolling_online_manager.routine()
print("========== collect results ==========")
print(self.rolling_online_manager.get_collector()())
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def add_strategy(self):
print("========== load ==========")
self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
print("========== add strategy ==========")
strategies = []
for task in self.add_tasks:
name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy
strategies.append(
RollingStrategy(
name_id,
task,
RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
)
)
self.rolling_online_manager.add_strategy(strategies=strategies)
print("========== dump ==========")
self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)
def main(self):
self.first_run()
self.routine()
self.add_strategy()
self.routine()
if __name__ == "__main__":
####### to train the first version's models, use the command below
# python rolling_online_management.py first_run
####### to update the models and predictions after the trading time, use the command below
# python rolling_online_management.py routine
####### to define your own parameters, use `--`
# python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40
fire.Fire(RollingOnlineExample)

View File

@@ -0,0 +1,91 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This example shows how OnlineTool works when we need update prediction.
There are two parts including first_train and update_online_pred.
Firstly, we will finish the training and set the trained models to the `online` models.
Next, we will finish updating online predictions.
"""
import fire
import qlib
from qlib.config import REG_CN
from qlib.model.trainer import task_train
from qlib.workflow.online.utils import OnlineToolR
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": {
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
}
class UpdatePredExample:
def __init__(
self, provider_uri="~/.qlib/qlib_data/cn_data", region=REG_CN, experiment_name="online_srv", task_config=task
):
qlib.init(provider_uri=provider_uri, region=region)
self.experiment_name = experiment_name
self.online_tool = OnlineToolR(self.experiment_name)
self.task_config = task_config
def first_train(self):
rec = task_train(self.task_config, experiment_name=self.experiment_name)
self.online_tool.reset_online_tag(rec) # set to online model
def update_online_pred(self):
self.online_tool.update_online_pred()
def main(self):
self.first_train()
self.update_online_pred()
if __name__ == "__main__":
## to train a model and set it to online model, use the command below
# python update_online_pred.py first_train
## to update online predictions once a day, use the command below
# python update_online_pred.py update_online_pred
## to see the whole process with your own parameters, use the command below
# python update_online_pred.py main --experiment_name="your_exp_name"
fire.Fire(UpdatePredExample)

View File

@@ -0,0 +1,17 @@
# Rolling Process Data
This workflow is an example for `Rolling Process Data`.
## Background
When rolling train the models, data also needs to be generated in the different rolling windows. When the rolling window moves, the training data will change, and the processor's learnable state (such as standard deviation, mean, etc.) will also change.
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then used Processors to generate processed-features related to the rolling window.
## Run the Code
Run the example by running the following command:
```bash
python workflow.py rolling_process
```

View File

@@ -0,0 +1,32 @@
from qlib.data.dataset.handler import DataHandlerLP
from qlib.data.dataset.loader import DataLoaderDH
from qlib.contrib.data.handler import check_transform_proc
class RollingDataHandler(DataHandlerLP):
def __init__(
self,
start_time=None,
end_time=None,
infer_processors=[],
learn_processors=[],
fit_start_time=None,
fit_end_time=None,
data_loader_kwargs={},
):
infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)
data_loader = {
"class": "DataLoaderDH",
"kwargs": {**data_loader_kwargs},
}
super().__init__(
instruments=None,
start_time=start_time,
end_time=end_time,
data_loader=data_loader,
infer_processors=infer_processors,
learn_processors=learn_processors,
)

View File

@@ -0,0 +1,141 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import qlib
import fire
import pickle
import pandas as pd
from datetime import datetime
from qlib.config import REG_CN
from qlib.data.dataset.handler import DataHandlerLP
from qlib.contrib.data.handler import Alpha158
from qlib.utils import exists_qlib_data, init_instance_by_config
from qlib.tests.data import GetData
class RollingDataWorkflow:
MARKET = "csi300"
start_time = "2010-01-01"
end_time = "2019-12-31"
rolling_cnt = 5
def _init_qlib(self):
"""initialize qlib"""
# use yahoo_cn_1min data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
def _dump_pre_handler(self, path):
handler_config = {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": self.start_time,
"end_time": self.end_time,
"instruments": self.MARKET,
"infer_processors": [],
"learn_processors": [],
},
}
pre_handler = init_instance_by_config(handler_config)
pre_handler.config(dump_all=True)
pre_handler.to_pickle(path)
def _load_pre_handler(self, path):
with open(path, "rb") as file_dataset:
pre_handler = pickle.load(file_dataset)
return pre_handler
def rolling_process(self):
self._init_qlib()
self._dump_pre_handler("pre_handler.pkl")
pre_handler = self._load_pre_handler("pre_handler.pkl")
train_start_time = (2010, 1, 1)
train_end_time = (2012, 12, 31)
valid_start_time = (2013, 1, 1)
valid_end_time = (2013, 12, 31)
test_start_time = (2014, 1, 1)
test_end_time = (2014, 12, 31)
dataset_config = {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "RollingDataHandler",
"module_path": "rolling_handler",
"kwargs": {
"start_time": datetime(*train_start_time),
"end_time": datetime(*test_end_time),
"fit_start_time": datetime(*train_start_time),
"fit_end_time": datetime(*train_end_time),
"infer_processors": [
{"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
],
"learn_processors": [
{"class": "DropnaLabel"},
{"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
],
"data_loader_kwargs": {
"handler_config": pre_handler,
},
},
},
"segments": {
"train": (datetime(*train_start_time), datetime(*train_end_time)),
"valid": (datetime(*valid_start_time), datetime(*valid_end_time)),
"test": (datetime(*test_start_time), datetime(*test_end_time)),
},
},
}
dataset = init_instance_by_config(dataset_config)
for rolling_offset in range(self.rolling_cnt):
print(f"===========rolling{rolling_offset} start===========")
if rolling_offset:
dataset.config(
handler_kwargs={
"start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"end_time": datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
"processor_kwargs": {
"fit_start_time": datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
"fit_end_time": datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
},
},
segments={
"train": (
datetime(train_start_time[0] + rolling_offset, *train_start_time[1:]),
datetime(train_end_time[0] + rolling_offset, *train_end_time[1:]),
),
"valid": (
datetime(valid_start_time[0] + rolling_offset, *valid_start_time[1:]),
datetime(valid_end_time[0] + rolling_offset, *valid_end_time[1:]),
),
"test": (
datetime(test_start_time[0] + rolling_offset, *test_start_time[1:]),
datetime(test_end_time[0] + rolling_offset, *test_end_time[1:]),
),
},
)
dataset.setup_data(
handler_kwargs={
"init_type": DataHandlerLP.IT_FIT_SEQ,
}
)
dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
print(dtrain, dvalid, dtest)
## print or dump data
print(f"===========rolling{rolling_offset} end===========")
if __name__ == "__main__":
fire.Fire(RollingDataWorkflow)

View File

@@ -28,11 +28,17 @@
"import sys, site\n",
"from pathlib import Path\n",
"\n",
"################################# NOTE #################################\n",
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
"# in this cell, users should RESTART the runtime in order to run the #\n",
"# following cells successfully. #\n",
"########################################################################\n",
"\n",
"try:\n",
" import qlib\n",
"except ImportError:\n",
" # install qlib\n",
" ! pip install --upgrade numpy\n",
" ! pip install pyqlib\n",
" # reload\n",
" site.main()\n",
@@ -238,9 +244,7 @@
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"metadata": {},
"outputs": [],
"source": [
"from qlib.contrib.report import analysis_model, analysis_position\n",
@@ -359,7 +363,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
"version": "3.8.3"
},
"toc": {
"base_numbering": 1,
@@ -377,4 +381,4 @@
},
"nbformat": 4,
"nbformat_minor": 4
}
}

View File

@@ -3,6 +3,7 @@
__version__ = "0.6.3.99"
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
import os
@@ -10,12 +11,13 @@ import yaml
import logging
import platform
import subprocess
from pathlib import Path
from .log import get_module_logger
# init qlib
def init(default_conf="client", **kwargs):
from .config import C
from .log import get_module_logger
from .data.cache import H
H.clear()
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
def _mount_nfs_uri(C):
from .log import get_module_logger
LOG = get_module_logger("mount nfs", level=logging.INFO)
@@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs):
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
"""
If users are building a project follow the following pattern.
- Qlib is a sub folder in project path
- There is a file named `config.yaml` in qlib.
For example:
If your project file system stucuture follows such a pattern
<project_path>/
- config.yaml
- ...some folders...
- qlib/
This folder will return <project_path>
NOTE: link is not supported here.
This method is often used when
- user want to use a relative config path instead of hard-coding qlib config path in code
Raises
------
FileNotFoundError:
If project path is not found
"""
if cur_path is None:
cur_path = Path(__file__).absolute().resolve()
while True:
if (cur_path / config_name).exists():
return cur_path
if cur_path == cur_path.parent:
raise FileNotFoundError("We can't find the project path")
cur_path = cur_path.parent
def auto_init(**kwargs):
"""
This function will init qlib automatically with following priority
- Find the project configuration and init qlib
- The parsing process will be affected by the `conf_type` of the configuration file
- Init qlib with default config
"""
try:
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
except FileNotFoundError:
init(**kwargs)
else:
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
conf_type = conf.get("conf_type", "origin")
if conf_type == "origin":
# The type of config is just like original qlib config
init_from_yaml_conf(conf_pp, **kwargs)
elif conf_type == "ref":
# This config type will be more convenient in following scenario
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -33,6 +33,9 @@ class Config:
raise AttributeError(f"No such {attr} in self._config")
def get(self, key, default=None):
return self.__dict__["_config"].get(key, default)
def __setitem__(self, key, value):
self.__dict__["_config"][key] = value
@@ -131,7 +134,7 @@ _default_config = {
},
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
},
# Defatult config for experiment manager
# Default config for experiment manager
"exp_manager": {
"class": "MLflowExpManager",
"module_path": "qlib.workflow.expm",
@@ -140,6 +143,11 @@ _default_config = {
"default_exp_name": "Experiment",
},
},
# Default config for MongoDB
"mongo": {
"task_url": "mongodb://localhost:27017/",
"task_db_name": "default_task_db",
},
}
MODE_CONF = {
@@ -310,8 +318,22 @@ class QlibConfig(Config):
# clean up experiment when python program ends
experiment_exit_handler()
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
self.reset_qlib_version()
self._registered = True
def reset_qlib_version(self):
import qlib
reset_version = self.get("qlib_reset_version", None)
if reset_version is not None:
qlib.__version__ = reset_version
else:
qlib.__version__ = getattr(qlib, "__version__bak")
# Due to a bug? that converting __version__ to _QlibConfig__version__bak
# Using __version__bak instead of __version__
@property
def registered(self):
return self._registered

View File

@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
"""Parameters
"""
Parameters
----------
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
def update_account(trade_account, trade_info, trade_exchange, trade_date):
"""Update the account and strategy
"""
Update the account and strategy
Parameters
----------
trade_account : Account()

View File

@@ -1,10 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
import copy
import pathlib
import pandas as pd
import numpy as np
from .order import Order
"""
@@ -128,7 +128,7 @@ class Position:
return self.position["cash"]
def get_stock_amount_dict(self):
"""generate stock amount dict {stock_id : amount of stock} """
"""generate stock amount dict {stock_id : amount of stock}"""
d = {}
stock_list = self.get_stock_list()
for stock_code in stock_list:

View File

@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
"fit_end_time": fit_end_time,
}
)
# FIXME: the `module_path` parameter is missed.
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
else:
new_l.append(p)

View File

@@ -8,6 +8,59 @@ import pandas as pd
from typing import Tuple
def calc_long_short_prec(
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
"""
calculate the precision for long and short operation
:param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**.
.. code-block:: python
score
datetime instrument
2020-12-01 09:30:00 SH600068 0.553634
SH600195 0.550017
SH600276 0.540321
SH600584 0.517297
SH600715 0.544674
label :
label
date_col :
date_col
Returns
-------
(pd.Series, pd.Series)
long precision and short precision in time level
"""
if is_alpha:
label = label - label.mean(level=date_col)
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
raise ValueError("Need more instruments to calculate precision")
df = pd.DataFrame({"pred": pred, "label": label})
if dropna:
df.dropna(inplace=True)
group = df.groupby(level=date_col)
N = lambda x: int(len(x) * quantile)
# find the top/low quantile of prediction and treat them as long and short target
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
groupll = long.groupby(date_col)
l_dom = groupll.apply(lambda x: x > 0)
l_c = groupll.count()
groups = short.groupby(date_col)
s_dom = groups.apply(lambda x: x < 0)
s_c = groups.count()
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
"""calc_ic.

View File

@@ -0,0 +1,39 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
try:
from .catboost_model import CatBoostModel
except ModuleNotFoundError:
CatBoostModel = None
print("Please install necessary libs for CatBoostModel.")
try:
from .double_ensemble import DEnsembleModel
from .gbdt import LGBModel
except ModuleNotFoundError:
DEnsembleModel, LGBModel = None, None
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
try:
from .xgboost import XGBModel
except ModuleNotFoundError:
XGBModel = None
print("Please install necessary libs for XGBModel, such as xgboost.")
try:
from .linear import LinearModel
except ModuleNotFoundError:
LinearModel = None
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
# import pytorch models
try:
from .pytorch_alstm import ALSTM
from .pytorch_gats import GATs
from .pytorch_gru import GRU
from .pytorch_lstm import LSTM
from .pytorch_nn import DNNModelPytorch
from .pytorch_tabnet import TabnetModel
from .pytorch_sfm import SFM_Model
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
except ModuleNotFoundError:
pytorch_classes = ()
print("Please install necessary libs for PyTorch models.")
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes

View File

@@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from catboost import Pool, CatBoost
from catboost.utils import get_gpu_device_count
@@ -62,10 +63,10 @@ class CatBoostModel(Model):
evals_result["train"] = list(evals_result["learn"].values())[0]
evals_result["valid"] = list(evals_result["validation"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)

View File

@@ -4,7 +4,7 @@
import lightgbm as lgb
import numpy as np
import pandas as pd
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -40,6 +40,10 @@ class DEnsembleModel(Model):
self.bins_sr = bins_sr
self.bins_fs = bins_fs
self.decay = decay
if sample_ratios is None: # the default values for sample_ratios
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
if sub_weights is None: # the default values for sub_weights
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
if not len(sample_ratios) == bins_fs:
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
self.sample_ratios = sample_ratios
@@ -228,10 +232,10 @@ class DEnsembleModel(Model):
raise ValueError("not implemented yet")
return loss_curve
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.ensemble is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
for i_sub, submodel in enumerate(self.ensemble):
feat_sub = self.sub_features[i_sub]

View File

@@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import lightgbm as lgb
from typing import Text, Union
from ...model.base import ModelFT
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -61,10 +61,10 @@ class LGBModel(ModelFT):
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):

View File

@@ -0,0 +1,157 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
import lightgbm as lgb
from qlib.model.base import ModelFT
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
import warnings
class HFLGBModel(ModelFT):
"""LightGBM Model for high frequency prediction"""
def __init__(self, loss="mse", **kwargs):
if loss not in {"mse", "binary"}:
raise NotImplementedError
self.params = {"objective": loss, "verbosity": -1}
self.params.update(kwargs)
self.model = None
def _cal_signal_metrics(self, y_test, l_cut, r_cut):
"""
Calcaute the signal metrics by daily level
"""
up_pre, down_pre = [], []
up_alpha_ll, down_alpha_ll = [], []
for date in y_test.index.get_level_values(0).unique():
df_res = y_test.loc[date].sort_values("pred")
if int(l_cut * len(df_res)) < 10:
warnings.warn("Warning: threhold is too low or instruments number is not enough")
continue
top = df_res.iloc[: int(l_cut * len(df_res))]
bottom = df_res.iloc[int(r_cut * len(df_res)) :]
down_precision = len(top[top[top.columns[0]] < 0]) / (len(top))
up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom))
down_alpha = top[top.columns[0]].mean()
up_alpha = bottom[bottom.columns[0]].mean()
up_pre.append(up_precision)
down_pre.append(down_precision)
up_alpha_ll.append(up_alpha)
down_alpha_ll.append(down_alpha)
return (
np.array(up_pre).mean(),
np.array(down_pre).mean(),
np.array(up_alpha_ll).mean(),
np.array(down_alpha_ll).mean(),
)
def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
"""
Test the sigal in high frequency test set
"""
if self.model == None:
raise ValueError("Model hasn't been trained yet")
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
df_test.dropna(inplace=True)
x_test, y_test = df_test["feature"], df_test["label"]
# Convert label into alpha
y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0)
res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
y_test["pred"] = res
up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
print("===============================")
print("High frequency signal test")
print("===============================")
print("Test set precision: ")
print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
print("Test Alpha Average in test set: ")
print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))
def _prepare_data(self, dataset: DatasetH):
df_train, df_valid = dataset.prepare(
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
)
x_train, y_train = df_train["feature"], df_train["label"]
x_valid, y_valid = df_train["feature"], df_valid["label"]
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
l_name = df_train["label"].columns[0]
# Convert label into alpha
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
mapping_fn = lambda x: 0 if x < 0 else 1
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
x_train, y_train = df_train["feature"], df_train["label_c"].values
x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values
else:
raise ValueError("LightGBM doesn't support multi-label training")
dtrain = lgb.Dataset(x_train.values, label=y_train)
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
return dtrain, dvalid
def fit(
self,
dataset: DatasetH,
num_boost_round=1000,
early_stopping_rounds=50,
verbose_eval=20,
evals_result=dict(),
**kwargs
):
dtrain, dvalid = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
valid_sets=[dtrain, dvalid],
valid_names=["train", "valid"],
early_stopping_rounds=early_stopping_rounds,
verbose_eval=verbose_eval,
evals_result=evals_result,
**kwargs
)
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
"""
finetune model
Parameters
----------
dataset : DatasetH
dataset for finetuning
num_boost_round : int
number of round to finetune model
verbose_eval : int
verbose level
"""
# Based on existing model and finetune by train more rounds
dtrain, _ = self._prepare_data(dataset)
self.model = lgb.train(
self.params,
dtrain,
num_boost_round=num_boost_round,
init_model=self.model,
valid_sets=[dtrain],
valid_names=["train"],
verbose_eval=verbose_eval,
)

View File

@@ -3,7 +3,7 @@
import numpy as np
import pandas as pd
from typing import Text, Union
from scipy.optimize import nnls
from sklearn.linear_model import LinearRegression, Ridge, Lasso
@@ -84,8 +84,8 @@ class LinearModel(Model):
self.coef_ = coef
self.intercept_ = 0.0
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.coef_ is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)

View File

@@ -8,13 +8,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -273,11 +269,11 @@ class ALSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.ALSTM_model.eval()
x_values = x_test.values

View File

@@ -8,13 +8,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -264,11 +260,11 @@ class ALSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
dl_test.config(fillna_type="ffill+bfill")
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
self.ALSTM_model.eval()

View File

@@ -8,13 +8,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn
@@ -83,7 +79,6 @@ class GATs(Model):
self.with_pretrain = with_pretrain
self.model_path = model_path
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
self.use_gpu = torch.cuda.is_available()
self.seed = seed
self.logger.info(
@@ -310,11 +305,11 @@ class GATs(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature")
index = x_test.index
self.GAT_model.eval()
x_values = x_test.values

View File

@@ -9,12 +9,7 @@ import os
import numpy as np
import pandas as pd
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
import torch.nn as nn

View File

@@ -8,13 +8,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -273,11 +269,11 @@ class GRU(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.gru_model.eval()
x_values = x_test.values

View File

@@ -9,12 +9,7 @@ import os
import numpy as np
import pandas as pd
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -126,8 +121,8 @@ class GRU(Model):
num_layers=self.num_layers,
dropout=self.dropout,
)
self.logger.info("model:\n{:}".format(self.gru_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
self.logger.info("model:\n{:}".format(self.GRU_model))
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
if optimizer.lower() == "adam":
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)

View File

@@ -8,13 +8,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -268,11 +264,11 @@ class LSTM(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.lstm_model.eval()
x_values = x_test.values
@@ -280,17 +276,13 @@ class LSTM(Model):
preds = []
for begin in range(sample_num)[:: self.batch_size]:
if sample_num - begin < self.batch_size:
end = sample_num
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.lstm_model(x_batch).detach().cpu().numpy()
preds.append(pred)
return pd.Series(np.concatenate(preds), index=index)

View File

@@ -9,12 +9,7 @@ import os
import numpy as np
import pandas as pd
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch

View File

@@ -8,6 +8,7 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
from sklearn.metrics import roc_auc_score, mean_squared_error
import torch
@@ -18,7 +19,7 @@ from .pytorch_utils import count_parameters
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
from ...log import get_module_logger
from ...workflow import R
@@ -48,8 +49,8 @@ class DNNModelPytorch(Model):
def __init__(
self,
input_dim,
output_dim,
input_dim=360,
output_dim=1,
layers=(256,),
lr=0.001,
max_steps=300,
@@ -271,13 +272,12 @@ class DNNModelPytorch(Model):
else:
raise NotImplementedError("loss {} is not supported!".format(loss_type))
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test_pd = dataset.prepare("test", col_set="feature")
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
self.dnn_model.eval()
with torch.no_grad():
preds = self.dnn_model(x_test).detach().cpu().numpy()
return pd.Series(np.squeeze(preds), index=x_test_pd.index)

View File

@@ -7,13 +7,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -442,11 +438,11 @@ class SFM(Model):
raise ValueError("unknown metric `%s`" % self.metric)
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.sfm_model.eval()
x_values = x_test.values
@@ -459,10 +455,7 @@ class SFM(Model):
else:
end = begin + self.batch_size
x_batch = torch.from_numpy(x_values[begin:end]).float()
if self.device != "cpu":
x_batch = x_batch.to(self.device)
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
with torch.no_grad():
pred = self.sfm_model(x_batch).detach().cpu().numpy()

View File

@@ -6,13 +6,9 @@ from __future__ import print_function
import os
import numpy as np
import pandas as pd
from typing import Text, Union
import copy
from ...utils import (
unpack_archive_with_buffer,
save_multiple_parts_file,
get_or_create_path,
drop_nan_by_y_index,
)
from ...utils import get_or_create_path
from ...log import get_module_logger
import torch
@@ -217,11 +213,11 @@ class TabnetModel(Model):
if self.use_gpu:
torch.cuda.empty_cache()
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if not self.fitted:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
index = x_test.index
self.tabnet_model.eval()
x_values = torch.from_numpy(x_test.values)

View File

@@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import xgboost as xgb
from typing import Text, Union
from ...model.base import Model
from ...data.dataset import DatasetH
from ...data.dataset.handler import DataHandlerLP
@@ -57,8 +57,8 @@ class XGBModel(Model):
evals_result["train"] = list(evals_result["train"].values())[0]
evals_result["valid"] = list(evals_result["valid"].values())[0]
def predict(self, dataset):
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare("test", col_set="feature")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)

View File

@@ -214,7 +214,7 @@ def cumulative_return_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
Graph desc:

View File

@@ -94,7 +94,7 @@ def rank_label_graph(
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
features_df.columns = ['label']
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.

View File

@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
qcr.report_graph(report_normal_df)
qcr.analysis_position.report_graph(report_normal_df)
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.

View File

@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
class BaseGraph:
""""""
""" """
_name = None

View File

@@ -251,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
"""
Gnererate order list according to score_series at trade_date, will not change current.
Generate order list according to score_series at trade_date, will not change current.
Parameters
-----------

View File

@@ -0,0 +1,4 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .record_temp import MultiSegRecord
from .record_temp import SignalMseRecord

View File

@@ -1,16 +1,60 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import logging
import pandas as pd
from sklearn.metrics import mean_squared_error
from pprint import pprint
import numpy as np
from sklearn.metrics import mean_squared_error
from typing import Dict, Text, Any
from ...contrib.eva.alpha import calc_ic
from ...workflow.record_temp import RecordTemp
from ...workflow.record_temp import SignalRecord
from ...data import dataset as qlib_dataset
from ...log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class MultiSegRecord(RecordTemp):
"""
This is the multiple segments signal record class that generates the signal prediction.
This class inherits the ``RecordTemp`` class.
"""
def __init__(self, model, dataset, recorder=None):
super().__init__(recorder=recorder)
if not isinstance(dataset, qlib_dataset.DatasetH):
raise ValueError("The type of dataset is not DatasetH instead of {:}".format(type(dataset)))
self.model = model
self.dataset = dataset
def generate(self, segments: Dict[Text, Any], save: bool = False):
for key, segment in segments.items():
predics = self.model.predict(self.dataset, segment)
if isinstance(predics, pd.Series):
predics = predics.to_frame("score")
labels = self.dataset.prepare(
segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
)
# Compute the IC and Rank IC
ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
logger.info("--- Results for {:} ({:}) ---".format(key, segment))
ic_x100, ric_x100 = ic * 100, ric * 100
logger.info("IC: {:.4f}%".format(ic_x100.mean()))
logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))
if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
)
)
class SignalMseRecord(SignalRecord):
@@ -38,7 +82,7 @@ class SignalMseRecord(SignalRecord):
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]

View File

@@ -1037,7 +1037,8 @@ class ClientProvider(BaseProvider):
self.logger = get_module_logger(self.__class__.__name__)
if isinstance(Cal, ClientCalendarProvider):
Cal.set_conn(self.client)
Inst.set_conn(self.client)
if isinstance(Inst, ClientInstrumentProvider):
Inst.set_conn(self.client)
if hasattr(DatasetD, "provider"):
DatasetD.provider.set_conn(self.client)
else:

View File

@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
from ...utils import init_instance_by_config, np_ffill
from ...log import get_module_logger
from .handler import DataHandler, DataHandlerLP
from copy import deepcopy
from inspect import getfullargspec
import pandas as pd
import numpy as np
@@ -16,22 +17,28 @@ class Dataset(Serializable):
Preparing data for model training and inferencing.
"""
def __init__(self, *args, **kwargs):
def __init__(self, **kwargs):
"""
init is designed to finish following steps:
- init the sub instance and the state of the dataset(info to prepare the data)
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
- setup data
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
- initialize the state of the dataset(info to prepare the data)
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
The data could specify the info to caculate the essential data for preparation
The data could specify the info to calculate the essential data for preparation
"""
self.setup_data(*args, **kwargs)
self.setup_data(**kwargs)
super().__init__()
def setup_data(self, *args, **kwargs):
def config(self, **kwargs):
"""
config is designed to configure and parameters that cannot be learned from the data
"""
super().config(**kwargs)
def setup_data(self, **kwargs):
"""
Setup the data.
@@ -39,7 +46,7 @@ class Dataset(Serializable):
- User have a Dataset object with learned status on disk.
- User load the Dataset object from the disk(Note the init function is skiped).
- User load the Dataset object from the disk.
- User call `setup_data` to load new data.
@@ -47,7 +54,7 @@ class Dataset(Serializable):
"""
pass
def prepare(self, *args, **kwargs) -> object:
def prepare(self, **kwargs) -> object:
"""
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
The parameters should specify the scope for the prepared data
@@ -76,44 +83,7 @@ class DatasetH(Dataset):
- The processing is related to data split.
"""
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
"""
Initialize the DatasetH
Parameters
----------
handler_kwargs : dict
Config of DataHanlder, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
- arguments of DataHandler.init, such as 'enable_cache', etc.
segment_kwargs : dict
Config of segments which is same as 'segments' in DatasetH.setup_data
"""
if handler_kwargs:
if not isinstance(handler_kwargs, dict):
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
kwargs_init = {}
kwargs_conf_data = {}
conf_data_arg = {"instruments", "start_time", "end_time"}
for k, v in handler_kwargs.items():
if k in conf_data_arg:
kwargs_conf_data.update({k: v})
else:
kwargs_init.update({k: v})
self.handler.conf_data(**kwargs_conf_data)
self.handler.init(**kwargs_init)
if segment_kwargs:
if not isinstance(segment_kwargs, dict):
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
self.segments = segment_kwargs.copy()
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
"""
Setup the underlying data.
@@ -122,7 +92,7 @@ class DatasetH(Dataset):
handler : Union[dict, DataHandler]
handler could be:
- insntance of `DataHandler`
- instance of `DataHandler`
- config of `DataHandler`. Please refer to `DataHandler`
@@ -142,8 +112,52 @@ class DatasetH(Dataset):
'outsample': ("2017-01-01", "2020-08-01",),
}
"""
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
self.segments = segments.copy()
self.fetch_kwargs = {}
super().__init__(**kwargs)
def config(self, handler_kwargs: dict = None, **kwargs):
"""
Initialize the DatasetH
Parameters
----------
handler_kwargs : dict
Config of DataHandler, which could include the following arguments:
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
kwargs : dict
Config of DatasetH, such as
- segments : dict
Config of segments which is same as 'segments' in self.__init__
"""
if handler_kwargs is not None:
self.handler.config(**handler_kwargs)
if "segments" in kwargs:
self.segments = deepcopy(kwargs.pop("segments"))
super().config(**kwargs)
def setup_data(self, handler_kwargs: dict = None, **kwargs):
"""
Setup the Data
Parameters
----------
handler_kwargs : dict
init arguments of DataHandler, which could include the following arguments:
- init_type : Init Type of Handler
- enable_cache : whether to enable cache
"""
super().setup_data(**kwargs)
if handler_kwargs is not None:
self.handler.setup_data(**handler_kwargs)
def __repr__(self):
return "{name}(handler={handler}, segments={segments})".format(
@@ -158,7 +172,10 @@ class DatasetH(Dataset):
----------
slc : slice
"""
return self.handler.fetch(slc, **kwargs)
if hasattr(self, "fetch_kwargs"):
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
else:
return self.handler.fetch(slc, **kwargs)
def prepare(
self,
@@ -186,6 +203,12 @@ class DatasetH(Dataset):
The data to fetch: DK_*
Default is DK_I, which indicate fetching data for **inference**.
kwargs :
The parameters that kwargs may contain:
flt_col : str
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
This parameter is only supported when it is an instance of TSDatasetH.
Returns
-------
Union[List[pd.DataFrame], pd.DataFrame]:
@@ -218,7 +241,7 @@ class TSDataSampler:
(T)ime-(S)eries DataSampler
This is the result of TSDatasetH
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
dataset based on tabular data.
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
@@ -230,7 +253,9 @@ class TSDataSampler:
"""
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
def __init__(
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
):
"""
Build a dataset which looks like torch.data.utils.Dataset.
@@ -252,6 +277,11 @@ class TSDataSampler:
ffill with previous sample
ffill+bfill:
ffill with previous samples first and fill with later samples second
flt_data : pd.Series
a column of data(True or False) to filter data.
None:
kepp all data
"""
self.start = start
self.end = end
@@ -259,23 +289,51 @@ class TSDataSampler:
self.fillna_type = fillna_type
assert get_level_index(data, "datetime") == 0
self.data = lazy_sort_index(data)
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE: append last line with full NaN for better performance in `__getitem__`
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
kwargs = {"object": self.data}
if dtype is not None:
kwargs["dtype"] = dtype
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
# NOTE:
# - append last line with full NaN for better performance in `__getitem__`
# - Keep the same dtype will result in a better performance
self.data_arr = np.append(
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
)
self.nan_idx = -1 # The last line is all NaN
# the data type will be changed
# The index of usable data is between start_idx and end_idx
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_df, self.idx_map = self.build_index(self.data)
self.data_index = deepcopy(self.data.index)
if flt_data is not None:
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
del self.data # save memory
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
new_idx_map = {}
for i, exist in enumerate(flt_data):
if exist:
new_idx_map[idx] = idx_map[i]
idx += 1
return new_idx_map
def get_index(self):
"""
Get the pandas index of the data, it will be useful in following scenarios
- Special sampler will be used (e.g. user want to sample day by day)
"""
return self.data.index[self.start_idx : self.end_idx]
return self.data_index[self.start_idx : self.end_idx]
def config(self, **kwargs):
# Config the attributes
@@ -419,7 +477,7 @@ class TSDatasetH(DatasetH):
(T)ime-(S)eries Dataset (H)andler
Covnert the tabular data to Time-Series data
Convert the tabular data to Time-Series data
Requirements analysis
@@ -433,18 +491,22 @@ class TSDatasetH(DatasetH):
- The dimension of a batch of data <batch_idx, feature, timestep>
"""
def __init__(self, step_len=30, *args, **kwargs):
def __init__(self, step_len=30, **kwargs):
self.step_len = step_len
super().__init__(*args, **kwargs)
super().__init__(**kwargs)
def setup_data(self, *args, **kwargs):
super().setup_data(*args, **kwargs)
def config(self, **kwargs):
if "step_len" in kwargs:
self.step_len = kwargs.pop("step_len")
super().config(**kwargs)
def setup_data(self, **kwargs):
super().setup_data(**kwargs)
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
# Get the datatime index for building timestamp
self.cal = cal
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
@@ -453,6 +515,25 @@ class TSDatasetH(DatasetH):
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
"""
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
return tsds

View File

@@ -6,7 +6,8 @@ import abc
import bisect
import logging
import warnings
from typing import Union, Tuple, List, Iterator, Optional
from inspect import getfullargspec
from typing import Callable, Union, Tuple, List, Iterator, Optional
import pandas as pd
import numpy as np
@@ -16,7 +17,7 @@ from ...data import D
from ...config import C
from ...utils import parse_config, transform_end_date, init_instance_by_config
from ...utils.serial import Serializable
from .utils import get_level_index, fetch_df_by_index
from .utils import fetch_df_by_index
from pathlib import Path
from .loader import DataLoader
@@ -35,7 +36,7 @@ class DataHandler(Serializable):
The data handler try to maintain a handler with 2 level.
`datetime` & `instruments`.
Any order of the index level can be suported (The order will be implied in the data).
Any order of the index level can be supported (The order will be implied in the data).
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
Example of the data:
@@ -50,6 +51,9 @@ class DataHandler(Serializable):
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
Tips for improving the performance of datahandler
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
"""
def __init__(
@@ -57,7 +61,7 @@ class DataHandler(Serializable):
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
data_loader: Union[dict, str, DataLoader] = None,
init_data=True,
fetch_orig=True,
):
@@ -70,10 +74,10 @@ class DataHandler(Serializable):
start_time of the original data.
end_time :
end_time of the original data.
data_loader : Tuple[dict, str, DataLoader]
data_loader : Union[dict, str, DataLoader]
data loader to load the data.
init_data :
intialize the original data in the constructor.
initialize the original data in the constructor.
fetch_orig : bool
Return the original data instead of copy if possible.
"""
@@ -99,10 +103,10 @@ class DataHandler(Serializable):
self.fetch_orig = fetch_orig
if init_data:
with TimeInspector.logt("Init data"):
self.init()
self.setup_data()
super().__init__()
def conf_data(self, **kwargs):
def config(self, **kwargs):
"""
configuration of data.
# what data to be loaded from data source
@@ -115,13 +119,16 @@ class DataHandler(Serializable):
for k, v in kwargs.items():
if k in attr_list:
setattr(self, k, v)
else:
raise KeyError("Such config is not supported.")
def init(self, enable_cache: bool = False):
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
def setup_data(self, enable_cache: bool = False):
"""
initialize the data.
In case of running intialization for multiple time, it will do nothing for the second time.
Set Up the data in case of running initialization for multiple time
It is responsible for maintaining following variable
1) self._data
@@ -159,6 +166,7 @@ class DataHandler(Serializable):
level: Union[str, int] = "datetime",
col_set: Union[str, List[str]] = CS_ALL,
squeeze: bool = False,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -181,6 +189,14 @@ class DataHandler(Serializable):
- if isinstance(col_set, List[str]):
select several sets of meaningful columns, the returned data has multiple levels
proc_func: Callable
- Give a hook for processing data before fetching
- An example to explain the necessity of the hook:
- A Dataset learned some processors to process data which is related to data segmentation
- It will apply them every time when preparing data.
- The learned processor require the dataframe remains the same format when fitting and applying
- However the data format will change according to the parameters.
- So the processors should be applied to the underlayer data.
squeeze : bool
whether squeeze columns and index
@@ -189,8 +205,15 @@ class DataHandler(Serializable):
-------
pd.DataFrame.
"""
if proc_func is None:
df = self._data
else:
# FIXME: fetching by time first will be more friendly to `proc_func`
# Copy in case of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(self._data, col_set)
df = self._fetch_df_by_col(df, col_set)
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
if squeeze:
# squeeze columns
@@ -257,6 +280,10 @@ class DataHandler(Serializable):
class DataHandlerLP(DataHandler):
"""
DataHandler with **(L)earnable (P)rocessor**
Tips to improving the performance of data handler
- To reduce the memory cost
- `drop_raw=True`: this will modify the data inplace on raw data;
"""
# data key
@@ -278,7 +305,7 @@ class DataHandlerLP(DataHandler):
instruments=None,
start_time=None,
end_time=None,
data_loader: Tuple[dict, str, DataLoader] = None,
data_loader: Union[dict, str, DataLoader] = None,
infer_processors=[],
learn_processors=[],
process_type=PTYPE_A,
@@ -405,14 +432,28 @@ class DataHandlerLP(DataHandler):
if self.drop_raw:
del self._data
def config(self, processor_kwargs: dict = None, **kwargs):
"""
configuration of data.
# what data to be loaded from data source
This method will be used when loading pickled handler from dataset.
The data will be initialized with different time range.
"""
super().config(**kwargs)
if processor_kwargs is not None:
for processor in self.get_all_processors():
processor.config(**processor_kwargs)
# init type
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
IT_LS = "load_state" # The state of the object has been load by pickle
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
"""
Initialize the data of Qlib
Set up the data in case of running initialization for multiple time
Parameters
----------
@@ -427,7 +468,7 @@ class DataHandlerLP(DataHandler):
when we call `init` next time
"""
# init raw data
super().init(enable_cache=enable_cache)
super().setup_data(**kwargs)
with TimeInspector.logt("fit & process data"):
if init_type == DataHandlerLP.IT_FIT_IND:
@@ -456,6 +497,7 @@ class DataHandlerLP(DataHandler):
level: Union[str, int] = "datetime",
col_set=DataHandler.CS_ALL,
data_key: str = DK_I,
proc_func: Callable = None,
) -> pd.DataFrame:
"""
fetch data from underlying data source
@@ -470,12 +512,18 @@ class DataHandlerLP(DataHandler):
select a set of meaningful columns.(e.g. features, columns).
data_key : str
the data to fetch: DK_*.
proc_func: Callable
please refer to the doc of DataHandler.fetch
Returns
-------
pd.DataFrame:
"""
df = self._get_df_by_key(data_key)
if proc_func is not None:
# FIXME: fetch by time first will be more friendly to proc_func
# Copy incase of `proc_func` changing the data inplace....
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
# Fetch column first will be more friendly to SepDataFrame
df = self._fetch_df_by_col(df, col_set)
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)

View File

@@ -13,6 +13,7 @@ from qlib.data import D
from qlib.data import filter as filter_module
from qlib.data.filter import BaseDFilter
from qlib.utils import load_dataset, init_instance_by_config
from qlib.log import get_module_logger
class DataLoader(abc.ABC):
@@ -217,3 +218,68 @@ class StaticDataLoader(DataLoader):
join=self.join,
)
self._data.sort_index(inplace=True)
class DataLoaderDH(DataLoader):
"""DataLoaderDH
DataLoader based on (D)ata (H)andler
It is designed to load multiple data from data handler
- If you just want to load data from single datahandler, you can write them in single data handler
TODO: What make this module not that easy to use.
- For online scenario
- The underlayer data handler should be configured. But data loader doesn't provide such interface & hook.
"""
def __init__(self, handler_config: dict, fetch_kwargs: dict = {}, is_group=False):
"""
Parameters
----------
handler_config : dict
handler_config will be used to describe the handlers
.. code-block::
<handler_config> := {
"group_name1": <handler>
"group_name2": <handler>
}
or
<handler_config> := <handler>
<handler> := DataHandler Instance | DataHandler Config
fetch_kwargs : dict
fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
is_group: bool
is_group will be used to describe whether the key of handler_config is group
"""
from qlib.data.dataset.handler import DataHandler
if is_group:
self.handlers = {
grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
}
else:
self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)
self.is_group = is_group
self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
self.fetch_kwargs.update(fetch_kwargs)
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is not None:
get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")
if self.is_group:
df = pd.concat(
{
grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
for grp, dh in self.handlers.items()
},
axis=1,
)
else:
df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
return df

18
qlib/data/dataset/processor.py Executable file → Normal file
View File

@@ -2,6 +2,7 @@
# Licensed under the MIT License.
import abc
from typing import Union, Text
import numpy as np
import pandas as pd
import copy
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
EPS = 1e-12
def get_group_columns(df: pd.DataFrame, group: str):
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
"""
get a group of columns from multi-index columns DataFrame
@@ -72,6 +73,17 @@ class Processor(Serializable):
"""
return True
def config(self, **kwargs):
attr_list = {"fit_start_time", "fit_end_time"}
for k, v in kwargs.items():
if k in attr_list and hasattr(self, k):
setattr(self, k, v)
for attr in attr_list:
if attr in kwargs:
kwargs.pop(attr)
super().config(**kwargs)
class DropnaProcessor(Processor):
def __init__(self, fields_group=None):
@@ -118,7 +130,7 @@ class FilterCol(Processor):
class TanhProcess(Processor):
""" Use tanh to process noise data"""
"""Use tanh to process noise data"""
def __call__(self, df):
def tanh_denoise(data):
@@ -133,7 +145,7 @@ class TanhProcess(Processor):
class ProcessInf(Processor):
"""Process infinity """
"""Process infinity"""
def __call__(self, df):
def replace_inf(data):

View File

@@ -12,7 +12,41 @@ from contextlib import contextmanager
from .config import C
def get_module_logger(module_name, level: Optional[int] = None):
class MetaLogger(type):
def __new__(cls, name, bases, dict):
wrapper_dict = logging.Logger.__dict__.copy()
for key in wrapper_dict:
if key not in dict and key != "__reduce__":
dict[key] = wrapper_dict[key]
return type.__new__(cls, name, bases, dict)
class QlibLogger(metaclass=MetaLogger):
"""
Customized logger for Qlib.
"""
def __init__(self, module_name):
self.module_name = module_name
self.level = 0
@property
def logger(self):
logger = logging.getLogger(self.module_name)
logger.setLevel(self.level)
return logger
def setLevel(self, level):
self.level = level
def __getattr__(self, name):
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
if name in {"__setstate__"}:
raise AttributeError
return self.logger.__getattribute__(name)
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
"""
Get a logger for a specific module.
@@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
module_name = "qlib.{}".format(module_name)
# Get logger.
module_logger = logging.getLogger(module_name)
module_logger = QlibLogger(module_name)
module_logger.setLevel(level)
return module_logger
@@ -129,3 +163,83 @@ class LogFilter(logging.Filter):
elif isinstance(self.param, list):
allow = not any([self.match_msg(p, record.msg) for p in self.param])
return allow
def set_global_logger_level(level: int, return_orig_handler_level: bool = False):
"""set qlib.xxx logger handlers level
Parameters
----------
level: int
logger level
return_orig_handler_level: bool
return origin handler level map
Examples
---------
.. code-block:: python
import qlib
import logging
from qlib.log import get_module_logger, set_global_logger_level
qlib.init()
tmp_logger_01 = get_module_logger("tmp_logger_01", level=logging.INFO)
tmp_logger_01.info("1. tmp_logger_01 info show")
global_level = logging.WARNING + 1
set_global_logger_level(global_level)
tmp_logger_02 = get_module_logger("tmp_logger_02", level=logging.INFO)
tmp_logger_02.log(msg="2. tmp_logger_02 log show", level=global_level)
tmp_logger_01.info("3. tmp_logger_01 info do not show")
"""
_handler_level_map = {}
qlib_logger = logging.root.manager.loggerDict.get("qlib", None)
if qlib_logger is not None:
for _handler in qlib_logger.handlers:
_handler_level_map[_handler] = _handler.level
_handler.level = level
return _handler_level_map if return_orig_handler_level else None
@contextmanager
def set_global_logger_level_cm(level: int):
"""set qlib.xxx logger handlers level to use contextmanager
Parameters
----------
level: int
logger level
Examples
---------
.. code-block:: python
import qlib
import logging
from qlib.log import get_module_logger, set_global_logger_level_cm
qlib.init()
tmp_logger_01 = get_module_logger("tmp_logger_01", level=logging.INFO)
tmp_logger_01.info("1. tmp_logger_01 info show")
global_level = logging.WARNING + 1
with set_global_logger_level_cm(global_level):
tmp_logger_02 = get_module_logger("tmp_logger_02", level=logging.INFO)
tmp_logger_02.log(msg="2. tmp_logger_02 log show", level=global_level)
tmp_logger_01.info("3. tmp_logger_01 info do not show")
tmp_logger_01.info("4. tmp_logger_01 info show")
"""
_handler_level_map = set_global_logger_level(level, return_orig_handler_level=True)
try:
yield
finally:
for _handler, _level in _handler_level_map.items():
_handler.level = _level

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import abc
from typing import Text, Union
from ..utils.serial import Serializable
from ..data.dataset import Dataset
@@ -10,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
@abc.abstractmethod
def predict(self, *args, **kwargs) -> object:
""" Make predictions after modeling things """
"""Make predictions after modeling things"""
pass
def __call__(self, *args, **kwargs) -> object:
""" leverage Python syntactic sugar to make the models' behaviors like functions """
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
return self.predict(*args, **kwargs)
@@ -59,7 +60,7 @@ class Model(BaseModel):
raise NotImplementedError()
@abc.abstractmethod
def predict(self, dataset: Dataset) -> object:
def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
"""give prediction given Dataset
Parameters
@@ -67,6 +68,9 @@ class Model(BaseModel):
dataset : Dataset
dataset will generate the processed dataset from model training.
segment : Text or slice
dataset will use this segment to prepare data. (default=test)
Returns
-------
Prediction results with certain type such as `pandas.Series`.

View File

115
qlib/model/ens/ensemble.py Normal file
View File

@@ -0,0 +1,115 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them into an ensemble prediction.
"""
from typing import Union
import pandas as pd
from qlib.utils import FLATTEN_TUPLE, flatten_dict
class Ensemble:
"""Merge the ensemble_dict into an ensemble object.
For example: {Rollinga_b: object, Rollingb_c: object} -> object
When calling this class:
Args:
ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging
Returns:
object: the ensemble object
"""
def __call__(self, ensemble_dict: dict, *args, **kwargs):
raise NotImplementedError(f"Please implement the `__call__` method.")
class SingleKeyEnsemble(Ensemble):
"""
Extract the object if there is only one key and value in the dict. Make the result more readable.
{Only key: Only value} -> Only value
If there is more than 1 key or less than 1 key, then do nothing.
Even you can run this recursively to make dict more readable.
NOTE: Default runs recursively.
When calling this class:
Args:
ensemble_dict (dict): the dict. The key of the dict will be ignored.
Returns:
dict: the readable dict.
"""
def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
if not isinstance(ensemble_dict, dict):
return ensemble_dict
if recursion:
tmp_dict = {}
for k, v in ensemble_dict.items():
tmp_dict[k] = self(v, recursion)
ensemble_dict = tmp_dict
keys = list(ensemble_dict.keys())
if len(keys) == 1:
ensemble_dict = ensemble_dict[keys[0]]
return ensemble_dict
class RollingEnsemble(Ensemble):
"""Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".
When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.DataFrame: the complete result of rolling.
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
artifact_list = list(ensemble_dict.values())
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
artifact = pd.concat(artifact_list)
# If there are duplicated predition, use the latest perdiction
artifact = artifact[~artifact.index.duplicated(keep="last")]
artifact = artifact.sort_index()
return artifact
class AverageEnsemble(Ensemble):
"""
Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.
NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.
When calling this class:
Args:
ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
The key of the dict will be ignored.
Returns:
pd.DataFrame: the complete result of averaging and standardizing.
"""
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
# need to flatten the nested dict
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
values = list(ensemble_dict.values())
results = pd.concat(values, axis=1)
results = results.groupby("datetime").apply(lambda df: (df - df.mean()) / df.std())
results = results.mean(axis=1)
results = results.sort_index()
return results

113
qlib/model/ens/group.py Normal file
View File

@@ -0,0 +1,113 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Group can group a set of objects based on `group_func` and change them to a dict.
After group, we provide a method to reduce them.
For example:
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
"""
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
from typing import Callable, Union
from joblib import Parallel, delayed
class Group:
"""Group the objects based on dict"""
def __init__(self, group_func=None, ens: Ensemble = None):
"""
Init Group.
Args:
group_func (Callable, optional): Given a dict and return the group key and one of the group elements.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
Defaults to None.
ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
"""
self._group_func = group_func
self._ens_func = ens
def group(self, *args, **kwargs) -> dict:
"""
Group a set of objects and change them to a dict.
For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
Returns:
dict: grouped dict
"""
if isinstance(getattr(self, "_group_func", None), Callable):
return self._group_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `group_func`.")
def reduce(self, *args, **kwargs) -> dict:
"""
Reduce grouped dict.
For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
Returns:
dict: reduced dict
"""
if isinstance(getattr(self, "_ens_func", None), Callable):
return self._ens_func(*args, **kwargs)
else:
raise NotImplementedError(f"Please specify valid `_ens_func`.")
def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
"""
Group the ungrouped_dict into different groups.
Args:
ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
Returns:
dict: grouped_dict like {G1: object, G2: object}
n_jobs: how many progress you need.
verbose: the print mode for Parallel.
"""
# NOTE: The multiprocessing will raise error if you use `Serializable`
# Because the `Serializable` will affect the behaviors of pickle
grouped_dict = self.group(ungrouped_dict, *args, **kwargs)
key_l = []
job_l = []
for key, value in grouped_dict.items():
key_l.append(key)
job_l.append(delayed(Group.reduce)(self, value))
return dict(zip(key_l, Parallel(n_jobs=n_jobs, verbose=verbose)(job_l)))
class RollingGroup(Group):
"""Group the rolling dict"""
def group(self, rolling_dict: dict) -> dict:
"""Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}}
NOTE: There is an assumption which is the rolling key is at the end of the key tuple, because the rolling results always need to be ensemble firstly.
Args:
rolling_dict (dict): an rolling dict. If the key is not a tuple, then do nothing.
Returns:
dict: grouped dict
"""
grouped_dict = {}
for key, values in rolling_dict.items():
if isinstance(key, tuple):
grouped_dict.setdefault(key[:-1], {})[key[-1]] = values
return grouped_dict
def __init__(self):
super().__init__(ens=RollingEnsemble())

View File

@@ -1,27 +0,0 @@
import abc
import typing
class TaskGen(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> typing.List[dict]:
"""
generate
Parameters
----------
args, kwargs:
The info for generating tasks
Example 1):
input: a specific task template
output: rolling version of the tasks
Example 2):
input: a specific task template
output: a set of tasks with different losses
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass

View File

@@ -1,42 +1,446 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config, flatten_dict
"""
The Trainer will train a list of tasks and return a list of model recorders.
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
"""
import socket
from typing import Callable, List
from qlib.data.dataset import Dataset
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.manage import TaskManager, run_task
def task_train(task_config: dict, experiment_name):
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
"""
task based training
Begin task training to start a recorder and save the task config.
Args:
task_config (dict): the config of a task
experiment_name (str): the name of experiment
recorder_name (str): the given name will be the recorder name. None for using rid.
Returns:
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
R.log_params(**flatten_dict(task_config))
R.save_objects(**{"task": task_config}) # keep the original format and datatype
R.set_tags(**{"hostname": socket.gethostname()})
recorder: Recorder = R.get_recorder()
return recorder
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
"""
Finish task training with real model fitting and saving.
Args:
rec (Recorder): the recorder will be resumed
experiment_name (str): the name of experiment
Returns:
Recorder: the model recorder
"""
with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
task_config = R.load_object("task")
# model & dataset initiation
model: Model = init_instance_by_config(task_config["model"])
dataset: Dataset = init_instance_by_config(task_config["dataset"])
# model training
model.fit(dataset)
R.save_objects(**{"params.pkl": model})
# this dataset is saved for online inference. So the concrete data should not be dumped
dataset.config(dump_all=False, recursive=True)
R.save_objects(**{"dataset": dataset})
# generate records: prediction, backtest, and analysis
records = task_config.get("record", [])
if isinstance(records, dict): # prevent only one dict
records = [records]
for record in records:
cls, kwargs = get_cls_kwargs(record, default_module="qlib.workflow.record_temp")
if cls is SignalRecord:
rconf = {"model": model, "dataset": dataset, "recorder": rec}
else:
rconf = {"recorder": rec}
r = cls(**kwargs, **rconf)
r.generate()
return rec
def task_train(task_config: dict, experiment_name: str) -> Recorder:
"""
Task based training, will be divided into two steps.
Parameters
----------
task_config : dict
A dict describes a task setting.
The config of a task.
experiment_name: str
The name of experiment
Returns
----------
Recorder: The instance of the recorder
"""
recorder = begin_task_train(task_config, experiment_name)
recorder = end_task_train(recorder, experiment_name)
return recorder
class Trainer:
"""
The trainer can train a list of models.
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
"""
# model initiaiton
model = init_instance_by_config(task_config["model"])
dataset = init_instance_by_config(task_config["dataset"])
def __init__(self):
self.delay = False
# start exp
with R.start(experiment_name=experiment_name):
# train model
R.log_params(**flatten_dict(task_config))
model.fit(dataset)
recorder = R.get_recorder()
R.save_objects(**{"params.pkl": model})
def train(self, tasks: list, *args, **kwargs) -> list:
"""
Given a list of task definitions, begin training, and return the models.
# generate records: prediction, backtest, and analysis
for record in task_config["record"]:
if record["class"] == SignalRecord.__name__:
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
record["kwargs"].update(srconf)
sr = init_instance_by_config(record)
sr.generate()
else:
rconf = {"recorder": recorder}
record["kwargs"].update(rconf)
ar = init_instance_by_config(record)
ar.generate()
For Trainer, it finishes real training in this method.
For DelayTrainer, it only does some preparation in this method.
Args:
tasks: a list of tasks
Returns:
list: a list of models
"""
raise NotImplementedError(f"Please implement the `train` method.")
def end_train(self, models: list, *args, **kwargs) -> list:
"""
Given a list of models, finished something at the end of training if you need.
The models may be Recorder, txt file, database, and so on.
For Trainer, it does some finishing touches in this method.
For DelayTrainer, it finishes real training in this method.
Args:
models: a list of models
Returns:
list: a list of models
"""
# do nothing if you finished all work in `train` method
return models
def is_delay(self) -> bool:
"""
If Trainer will delay finishing `end_train`.
Returns:
bool: if DelayTrainer
"""
return self.delay
class TrainerR(Trainer):
"""
Trainer based on (R)ecorder.
It will train a list of tasks and return a list of model recorders in a linear way.
Assumption: models were defined by `task` and the results will be saved to `Recorder`.
"""
# Those tag will help you distinguish whether the Recorder has finished traning
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
"""
Init TrainerR.
Args:
experiment_name (str, optional): the default name of experiment.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.train_func = train_func
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
recs = []
for task in tasks:
rec = train_func(task, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
class DelayTrainerR(TrainerR):
"""
A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
"""
Init TrainerRM.
Args:
experiment_name (str): the default name of experiment.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, train_func)
self.end_train_func = end_train_func
self.delay = True
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
Args:
recs (list): a list of Recorder, the tasks have been saved to them
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
for rec in recs:
if rec.list_tags()[self.STATUS_KEY] == self.STATUS_END:
continue
end_train_func(rec, experiment_name, **kwargs)
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
class TrainerRM(Trainer):
"""
Trainer based on (R)ecorder and Task(M)anager.
It can train a list of tasks and return a list of model recorders in a multiprocessing way.
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
"""
# Those tag will help you distinguish whether the Recorder has finished traning
STATUS_KEY = "train_status"
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
"""
Init TrainerR.
Args:
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default training method. Defaults to `task_train`.
"""
super().__init__()
self.experiment_name = experiment_name
self.task_pool = task_pool
self.train_func = train_func
def train(
self,
tasks: list,
train_func: Callable = None,
experiment_name: str = None,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
) -> List[Recorder]:
"""
Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.
This method defaults to a single process, but TaskManager offered a great way to parallel training.
Users can customize their train_func to realize multiple processes or even multiple machines.
Args:
tasks (list): a list of definitions based on `task` dict
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
experiment_name (str): the experiment name, None for use default name.
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs: the params for train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
if train_func is None:
train_func = self.train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
run_task(
train_func,
task_pool,
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
recs.append(rec)
return recs
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
"""
Set STATUS_END tag to the recorders.
Args:
recs (list): a list of trained recorders.
Returns:
List[Recorder]: the same list as the param.
"""
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
class DelayTrainerRM(TrainerRM):
"""
A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
"""
def __init__(
self,
experiment_name: str = None,
task_pool: str = None,
train_func=begin_task_train,
end_train_func=end_task_train,
):
"""
Init DelayTrainerRM.
Args:
experiment_name (str): the default name of experiment.
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
"""
super().__init__(experiment_name, task_pool, train_func)
self.end_train_func = end_train_func
self.delay = True
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
Args:
tasks (list): a list of definition based on `task` dict
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
experiment_name (str): the experiment name, None for use default name.
Returns:
List[Recorder]: a list of Recorders
"""
if len(tasks) == 0:
return []
return super().train(
tasks,
train_func=train_func,
experiment_name=experiment_name,
after_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
"""
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
experiment_name (str): the experiment name, None for use default name.
kwargs: the params for end_train_func.
Returns:
List[Recorder]: a list of Recorders
"""
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
experiment_name = self.experiment_name
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
tasks = []
for rec in recs:
tasks.append(rec.load_object("task"))
run_task(
end_train_func,
task_pool,
query={"filter": {"$in": tasks}}, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs

View File

@@ -5,9 +5,9 @@ import abc
class BaseOptimizer(abc.ABC):
""" Construct portfolio with a optimization related method """
"""Construct portfolio with a optimization related method"""
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> object:
""" Generate a optimized portfolio allocation """
"""Generate a optimized portfolio allocation"""
pass

View File

@@ -6,6 +6,7 @@ from __future__ import division
from __future__ import print_function
import os
import pickle
import re
import copy
import json
@@ -24,7 +25,9 @@ import collections
import numpy as np
import pandas as pd
from pathlib import Path
from typing import Union, Tuple, Text, Optional
from typing import Union, Tuple, Any, Text, Optional
from types import ModuleType
from urllib.parse import urlparse
from ..config import C
from ..log import get_module_logger, set_log_with_config
@@ -165,24 +168,25 @@ def parse_field(field):
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
def get_module_by_module_path(module_path):
def get_module_by_module_path(module_path: Union[str, ModuleType]):
"""Load module path
:param module_path:
:return:
"""
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
if isinstance(module_path, ModuleType):
module = module_path
else:
module = importlib.import_module(module_path)
if module_path.endswith(".py"):
module_spec = importlib.util.spec_from_file_location("", module_path)
module = importlib.util.module_from_spec(module_spec)
module_spec.loader.exec_module(module)
else:
module = importlib.import_module(module_path)
return module
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
"""
extract class and kwargs from config info
@@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
config : [dict, str]
similar to config
module : Python module
default_module : Python module or str
It should be a python module to load the class type
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
Returns
-------
@@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
the class object and it's arguments.
"""
if isinstance(config, dict):
module = get_module_by_module_path(config.get("module_path", default_module))
# raise AttributeError
klass = getattr(module, config["class"])
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
klass = getattr(module, config)
kwargs = {}
else:
@@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
def init_instance_by_config(
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> object:
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
) -> Any:
"""
get initialized instance with config
@@ -227,13 +237,19 @@ def init_instance_by_config(
'model_path': path, # It is optional if module is given
}
str example.
"ClassName": getattr(module, config)() will be used.
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'
2) specify a class name
- "ClassName": getattr(module, config)() will be used.
object example:
instance of accept_types
module : Python module
default_module : Python module
Optional. It should be a python module.
NOTE: the "module_path" will be override by `module` arguments
This function will load class from the config['module_path'] first.
If config['module_path'] doesn't exists, it will load the class from default_module.
accept_types: Union[type, Tuple[type]]
Optional. If the config is a instance of specific type, return the config directly.
This will be passed into the second parameter of isinstance.
@@ -246,10 +262,14 @@ def init_instance_by_config(
if isinstance(config, accept_types):
return config
if module is None:
module = get_module_by_module_path(config["module_path"])
if isinstance(config, str):
# path like 'file:///<path to pickle file>/obj.pkl'
pr = urlparse(config)
if pr.scheme == "file":
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
return pickle.load(f)
klass, cls_kwargs = get_cls_kwargs(config, module)
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
return klass(**cls_kwargs, **kwargs)
@@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
return calendar
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
"""get trading date with shift bias wil cur_date
e.g. : shift == 1, return next trading date
shift == -1, return previous trading date
@@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
"""
from qlib.data import D
cal = D.calendar(future=future)
cal = D.calendar(future=future, freq=freq)
if pd.to_datetime(trading_date) not in list(cal):
raise ValueError("{} is not trading day!".format(str(trading_date)))
_index = bisect.bisect_left(cal, trading_date)
@@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
return df.sort_index(axis=axis)
def flatten_dict(d, parent_key="", sep="."):
"""flatten_dict.
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
def flatten_dict(d, parent_key="", sep=".") -> dict:
"""
Flatten a nested dict.
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
Parameters
----------
d :
d
parent_key :
parent_key
sep :
sep
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
Args:
d (dict): the dict waiting for flatting
parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
Returns:
dict: flatten dict
"""
items = []
for k, v in d.items():
new_key = parent_key + sep + k if parent_key else k
if sep == FLATTEN_TUPLE:
new_key = (parent_key, k) if parent_key else k
else:
new_key = parent_key + sep + k if parent_key else k
if isinstance(v, collections.abc.MutableMapping):
items.extend(flatten_dict(v, new_key, sep=sep).items())
else:

View File

@@ -3,16 +3,24 @@
from pathlib import Path
import pickle
import typing
import dill
from typing import Union
class Serializable:
"""
Serializable behaves like pickle.
But it only saves the state whose name **does not** start with `_`
Serializable will change the behaviors of pickle.
- It only saves the state whose name **does not** start with `_`
It provides a syntactic sugar for distinguish the attributes which user doesn't want.
- For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk
"""
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
default_dump_all = False # if dump all things
def __init__(self):
self._dump_all = False
self._dump_all = self.default_dump_all
self._exclude = []
def __getstate__(self) -> dict:
@@ -33,18 +41,86 @@ class Serializable:
@property
def exclude(self):
"""
What attribute will be dumped
What attribute will not be dumped
"""
return getattr(self, "_exclude", [])
def config(self, dump_all: bool = None, exclude: list = None):
if dump_all is not None:
self._dump_all = dump_all
FLAG_KEY = "_qlib_serial_flag"
if exclude is not None:
self._exclude = exclude
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
"""
configure the serializable object
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
Parameters
----------
dump_all : bool
will the object dump all object
exclude : list
What attribute will not be dumped
recursive : bool
will the configuration be recursive
"""
params = {"dump_all": dump_all, "exclude": exclude}
for k, v in params.items():
if v is not None:
attr_name = f"_{k}"
setattr(self, attr_name, v)
if recursive:
for obj in self.__dict__.values():
# set flag to prevent endless loop
self.__dict__[self.FLAG_KEY] = True
if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
obj.config(**params, recursive=True)
del self.__dict__[self.FLAG_KEY]
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
"""
Dump self to a pickle file.
Args:
path (Union[Path, str]): the path to dump
dump_all (bool, optional): if need to dump all things. Defaults to None.
exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
"""
self.config(dump_all=dump_all, exclude=exclude)
with Path(path).open("wb") as f:
pickle.dump(self, f)
self.get_backend().dump(self, f)
@classmethod
def load(cls, filepath):
"""
Load the collector from a filepath.
Args:
filepath (str): the path of file
Raises:
TypeError: the pickled file must be `Collector`
Returns:
Collector: the instance of Collector
"""
with open(filepath, "rb") as f:
object = cls.get_backend().load(f)
if isinstance(object, cls):
return object
else:
raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!")
@classmethod
def get_backend(cls):
"""
Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
Returns:
module: pickle or dill module based on pickle_backend
"""
if cls.pickle_backend == "pickle":
return pickle
elif cls.pickle_backend == "dill":
return dill
else:
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")

View File

@@ -23,7 +23,10 @@ class QlibRecorder:
@contextmanager
def start(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -45,8 +48,12 @@ class QlibRecorder:
Parameters
----------
experiment_id : str
id of the experiment one wants to start.
experiment_name : str
name of the experiment one wants to start.
recorder_id : str
id of the recorder under the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
@@ -57,7 +64,14 @@ class QlibRecorder:
resume : bool
whether to resume the specific recorder with given name under the given experiment.
"""
run = self.start_exp(experiment_name, recorder_name, uri, resume)
run = self.start_exp(
experiment_id=experiment_id,
experiment_name=experiment_name,
recorder_id=recorder_id,
recorder_name=recorder_name,
uri=uri,
resume=resume,
)
try:
yield run
except Exception as e:
@@ -65,7 +79,9 @@ class QlibRecorder:
raise e
self.end_exp(Recorder.STATUS_FI)
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
def start_exp(
self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False
):
"""
Lower level method for starting an experiment. When use this method, one should end the experiment manually
and the status of the recorder may not be handled properly. Here is the example code:
@@ -79,8 +95,12 @@ class QlibRecorder:
Parameters
----------
experiment_id : str
id of the experiment one wants to start.
experiment_name : str
the name of the experiment to be started
recorder_id : str
id of the recorder under the experiment one wants to start.
recorder_name : str
name of the recorder under the experiment one wants to start.
uri : str
@@ -93,7 +113,14 @@ class QlibRecorder:
-------
An experiment instance being started.
"""
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
return self.exp_manager.start_exp(
experiment_id=experiment_id,
experiment_name=experiment_name,
recorder_id=recorder_id,
recorder_name=recorder_name,
uri=uri,
resume=resume,
)
def end_exp(self, recorder_status=Recorder.STATUS_FI):
"""
@@ -202,13 +229,13 @@ class QlibRecorder:
- no id or name specified, return the active experiment.
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name.
- If `active experiment` not exists:
- no id or name specified, create a default experiment, and the experiment is set to be active.
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be active.
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment.
- Else If '`create`' is False:
@@ -260,7 +287,7 @@ class QlibRecorder:
-------
An experiment instance with given id or name.
"""
return self.exp_manager.get_exp(experiment_id, experiment_name, create)
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
def delete_exp(self, experiment_id=None, experiment_name=None):
"""
@@ -304,7 +331,7 @@ class QlibRecorder:
"""
self.exp_manager.set_uri(uri)
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
"""
Method for retrieving a recorder.
@@ -358,7 +385,7 @@ class QlibRecorder:
A recorder instance.
"""
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
recorder_id, recorder_name, create=False
recorder_id, recorder_name, create=False, start=False
)
def delete_recorder(self, recorder_id=None, recorder_name=None):
@@ -416,6 +443,12 @@ class QlibRecorder:
"""
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
def load_object(self, name: Text):
"""
Method for loading an object from artifacts in the experiment in the uri.
"""
return self.get_exp().get_recorder().load_object(name)
def log_params(self, **kwargs):
"""
Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.

View File

@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import mlflow, logging
from mlflow.entities import ViewType
from mlflow.exceptions import MlflowException
from pathlib import Path
from .recorder import Recorder, MLflowRecorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class Experiment:
@@ -39,12 +39,14 @@ class Experiment:
output["recorders"] = list(recorders.keys())
return output
def start(self, recorder_name=None, resume=False):
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
"""
Start the experiment and set it to be active. This method will also start a new recorder.
Parameters
----------
recorder_id : str
the id of the recorder to be created.
recorder_name : str
the name of the recorder to be created.
resume : bool
@@ -107,24 +109,24 @@ class Experiment:
"""
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True):
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False):
"""
Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
specific recorder. When user does not provide recorder id or name, the method will try to return the current
active recorder. The `create` argument determines whether the method will automatically create a new recorder
according to user's specification if the recorder hasn't been created before
according to user's specification if the recorder hasn't been created before.
* If `create` is True:
* If `active recorder` exists:
* no id or name specified, return the active recorder.
* if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
* if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
* If `active recorder` not exists:
* no id or name specified, create a new recorder.
* if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
* if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
* Else If `create` is False:
@@ -146,6 +148,8 @@ class Experiment:
the name of the recorder to be deleted.
create : boolean
create the recorder if it hasn't been created before.
start : boolean
start the new recorder if one is created.
Returns
-------
@@ -159,8 +163,11 @@ class Experiment:
if create:
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
else:
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
if is_new:
recorder, is_new = (
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
False,
)
if is_new and start:
self.active_recorder = recorder
# start the recorder
self.active_recorder.start_run()
@@ -174,7 +181,10 @@ class Experiment:
try:
if recorder_id is None and recorder_name is None:
recorder_name = self._default_rec_name
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
return (
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
False,
)
except ValueError:
if recorder_name is None:
recorder_name = self._default_rec_name
@@ -230,14 +240,14 @@ class MLflowExperiment(Experiment):
def __repr__(self):
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
def start(self, recorder_name=None, resume=False):
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
logger.info(f"Experiment {self.id} starts running ...")
# Get or create recorder
if recorder_name is None:
recorder_name = self._default_rec_name
# resume the recorder
if resume:
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
# create a new recorder
else:
recorder = self.create_recorder(recorder_name)

View File

@@ -4,7 +4,7 @@
import mlflow
from mlflow.exceptions import MlflowException
from mlflow.entities import ViewType
import os
import os, logging
from pathlib import Path
from contextlib import contextmanager
from typing import Optional, Text
@@ -14,7 +14,7 @@ from ..config import C
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class ExpManager:
@@ -33,7 +33,10 @@ class ExpManager:
def start_exp(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -45,8 +48,12 @@ class ExpManager:
Parameters
----------
experiment_id : str
id of the active experiment.
experiment_name : str
name of the active experiment.
recorder_id : str
id of the recorder to be started.
recorder_name : str
name of the recorder to be started.
uri : str
@@ -102,10 +109,9 @@ class ExpManager:
"""
raise NotImplementedError(f"Please implement the `search_records` method.")
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
"""
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
The returned experiment will be active.
When user specify experiment id and name, the method will try to return the specific experiment.
When user does not provide recorder id or name, the method will try to return the current active experiment.
@@ -117,12 +123,12 @@ class ExpManager:
* If `active experiment` exists:
* no id or name specified, return the active experiment.
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
* If `active experiment` not exists:
* no id or name specified, create a default experiment.
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
* Else If `create` is False:
@@ -144,6 +150,8 @@ class ExpManager:
name of the experiment to return.
create : boolean
create the experiment it if hasn't been created before.
start : boolean
start the new experiment if one is created.
Returns
-------
@@ -159,8 +167,11 @@ class ExpManager:
if create:
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
else:
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
if is_new:
exp, is_new = (
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
if is_new and start:
self.active_experiment = exp
# start the recorder
self.active_experiment.start()
@@ -172,7 +183,10 @@ class ExpManager:
automatically create a new experiment based on the given id and name.
"""
try:
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
return (
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
False,
)
except ValueError:
if experiment_name is None:
experiment_name = self._default_exp_name
@@ -291,7 +305,10 @@ class MLflowExpManager(ExpManager):
def start_exp(
self,
*,
experiment_id: Optional[Text] = None,
experiment_name: Optional[Text] = None,
recorder_id: Optional[Text] = None,
recorder_name: Optional[Text] = None,
uri: Optional[Text] = None,
resume: bool = False,
@@ -301,11 +318,11 @@ class MLflowExpManager(ExpManager):
# Create experiment
if experiment_name is None:
experiment_name = self._default_exp_name
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
# Set up active experiment
self.active_experiment = experiment
# Start the experiment
self.active_experiment.start(recorder_name, resume)
self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)
return self.active_experiment

View File

View File

@@ -0,0 +1,304 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
So this module provides a series of methods to control this process.
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
Which means you can verify your strategy or find a better one.
There are 4 total situations for using different trainers in different situations:
========================= ===================================================================================
Situations Description
========================= ===================================================================================
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
every routine and prepare signals for every routine.
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
"""
import logging
from typing import Callable, Dict, List, Union
import pandas as pd
from qlib import get_module_logger
from qlib.data.data import D
from qlib.log import set_global_logger_level
from qlib.model.ens.ensemble import AverageEnsemble
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
from qlib.utils import flatten_dict
from qlib.utils.serial import Serializable
from qlib.workflow.online.strategy import OnlineStrategy
from qlib.workflow.task.collect import MergeCollector
class OnlineManager(Serializable):
"""
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
It also provides a history recording of which models are online at what time.
"""
STATUS_SIMULATING = "simulating" # when calling `simulate`
STATUS_NORMAL = "normal" # the normal status
def __init__(
self,
strategies: Union[OnlineStrategy, List[OnlineStrategy]],
trainer: Trainer = None,
begin_time: Union[str, pd.Timestamp] = None,
freq="day",
):
"""
Init OnlineManager.
One OnlineManager must have at least one OnlineStrategy.
Args:
strategies (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy
begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using the latest date.
trainer (Trainer): the trainer to train task. None for using TrainerR.
freq (str, optional): data frequency. Defaults to "day".
"""
self.logger = get_module_logger(self.__class__.__name__)
if not isinstance(strategies, list):
strategies = [strategies]
self.strategies = strategies
self.freq = freq
if begin_time is None:
begin_time = D.calendar(freq=self.freq).max()
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
self.history = {}
if trainer is None:
trainer = TrainerR()
self.trainer = trainer
self.signals = None
self.status = self.STATUS_NORMAL
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
Get tasks from every strategy's first_tasks method and train them.
If using DelayTrainer, it can finish training all together after every strategy's first_tasks.
Args:
strategies (List[OnlineStrategy]): the strategies list (need this param when adding strategies). None for use default strategies.
model_kwargs (dict): the params for `prepare_online_models`
"""
if strategies is None:
strategies = self.strategies
for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
task_kwargs: dict = {},
model_kwargs: dict = {},
signal_kwargs: dict = {},
):
"""
Typical update process for every strategy and record the online history.
The typical update process after a routine, such as day by day or month by month.
The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.
If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.
Args:
cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None.
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
signal_kwargs (dict): the params for `prepare_signals`
"""
if cur_time is None:
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
models = self.trainer.train(tasks)
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.trainer.is_delay():
self.prepare_signals(**signal_kwargs)
def get_collector(self) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
collector_dict[strategy.name_id] = strategy.get_collector()
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
"""
Add some new strategies to OnlineManager.
Args:
strategy (Union[OnlineStrategy, List[OnlineStrategy]]): a list of OnlineStrategy
"""
if not isinstance(strategies, list):
strategies = [strategies]
self.first_train(strategies)
self.strategies.extend(strategies)
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
"""
After preparing the data of the last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for the next routine.
NOTE: Given a set prediction, all signals before these prediction end times will be prepared well.
Even if the latest signal already exists, the latest calculation result will be overwritten.
.. note::
Given a prediction of a certain time, all signals before this time will be prepared well.
Args:
prepare_func (Callable, optional): Get signals from a dict after collecting. Defaults to AverageEnsemble(), the results collected by MergeCollector must be {xxx:pred}.
over_write (bool, optional): If True, the new signals will overwrite. If False, the new signals will append to the end of signals. Defaults to False.
Returns:
pd.DataFrame: the signals.
"""
signals = prepare_func(self.get_collector()())
old_signals = self.signals
if old_signals is not None and not over_write:
old_max = old_signals.index.get_level_values("datetime").max()
new_signals = signals.loc[old_max:]
signals = pd.concat([old_signals, new_signals], axis=0)
else:
new_signals = signals
self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
self.signals = signals
return new_signals
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
"""
Get prepared online signals.
Returns:
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
return self.signals
SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information
SIM_LOG_NAME = "SIMULATE_INFO"
def simulate(
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
) -> Union[pd.Series, pd.DataFrame]:
"""
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
Considering the parallel training, the models and signals can be prepared after all routine simulating.
The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.
Args:
end_time: the time the simulation will end
frequency: the calendar frequency
task_kwargs (dict): the params for `prepare_tasks`
model_kwargs (dict): the params for `prepare_online_models`
signal_kwargs (dict): the params for `prepare_signals`
Returns:
Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
"""
self.status = self.STATUS_SIMULATING
cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
self.first_train()
simulate_level = self.SIM_LOG_LEVEL
set_global_logger_level(simulate_level)
logging.addLevelName(simulate_level, self.SIM_LOG_NAME)
for cur_time in cal:
self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
self.routine(
cur_time,
task_kwargs=task_kwargs,
model_kwargs=model_kwargs,
signal_kwargs=signal_kwargs,
)
# delay prepare the models and signals
if self.trainer.is_delay():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
self.status = self.STATUS_NORMAL
return self.get_signals()
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
"""
Prepare all models and signals if something is waiting for preparation.
Args:
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
last_models = {}
signals_time = D.calendar()[0]
need_prepare = False
for cur_time, strategy_models in self.history.items():
self.cur_time = cur_time
for strategy, models in strategy_models.items():
# only new online models need to prepare
if last_models.setdefault(strategy, set()) != set(models):
models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
strategy.tool.reset_online_tag(models)
need_prepare = True
last_models[strategy] = set(models)
if need_prepare:
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)
need_prepare = False
signals_time = self.signals.index.get_level_values("datetime").max()

View File

@@ -0,0 +1,211 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineStrategy module is an element of online serving.
"""
from copy import deepcopy
from typing import List, Tuple, Union
from qlib.data.data import D
from qlib.log import get_module_logger
from qlib.model.ens.group import RollingGroup
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.collect import Collector, RecorderCollector
from qlib.workflow.task.gen import RollingGen, task_generator
from qlib.workflow.task.utils import TimeAdjuster
class OnlineStrategy:
"""
OnlineStrategy is working with `Online Manager <#Online Manager>`_, responding to how the tasks are generated, the models are updated and signals are prepared.
"""
def __init__(self, name_id: str):
"""
Init OnlineStrategy.
This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training.
Args:
name_id (str): a unique name or id.
trainer (Trainer, optional): a instance of Trainer. Defaults to None.
"""
self.name_id = name_id
self.logger = get_module_logger(self.__class__.__name__)
self.tool = OnlineTool()
def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
"""
After the end of a routine, check whether we need to prepare and train some new tasks based on cur_time (None for latest)..
Return the new tasks waiting for training.
You can find the last online models by OnlineTool.online_models.
"""
raise NotImplementedError(f"Please implement the `prepare_tasks` method.")
def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
"""
Select some models from trained models and set them to online models.
This is a typical implementation to online all trained models, you can override it to implement the complex method.
You can find the last online models by OnlineTool.online_models if you still need them.
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
Args:
models (list): a list of models.
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
Returns:
List[object]: a list of online models.
"""
if not trained_models:
return self.tool.online_models()
self.tool.reset_online_tag(trained_models)
return trained_models
def first_tasks(self) -> List[dict]:
"""
Generate a series of tasks firstly and return them.
"""
raise NotImplementedError(f"Please implement the `first_tasks` method.")
def get_collector(self) -> Collector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.
For example:
1) collect predictions in Recorder
2) collect signals in a txt file
Returns:
Collector
"""
raise NotImplementedError(f"Please implement the `get_collector` method.")
class RollingStrategy(OnlineStrategy):
"""
This example strategy always uses the latest rolling model sas online models.
"""
def __init__(
self,
name_id: str,
task_template: Union[dict, List[dict]],
rolling_gen: RollingGen,
):
"""
Init RollingStrategy.
Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.
Args:
name_id (str): a unique name or id. Will be also the name of the Experiment.
task_template (Union[dict, List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen.
rolling_gen (RollingGen): an instance of RollingGen
"""
super().__init__(name_id=name_id)
self.exp_name = self.name_id
if not isinstance(task_template, list):
task_template = [task_template]
self.task_template = task_template
self.rg = rolling_gen
self.tool = OnlineToolR(self.exp_name)
self.ta = TimeAdjuster()
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.
Assumption: the models can be distinguished based on the model name and rolling test segments.
If you do not want this assumption, please implement your method or use another rec_key_func.
Args:
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
"""
def rec_key(recorder):
task_config = recorder.load_object("task")
model_key = task_config["model"]["class"]
rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
return model_key, rolling_key
if rec_key_func is None:
rec_key_func = rec_key
artifacts_collector = RecorderCollector(
experiment=self.exp_name,
process_list=process_list,
rec_key_func=rec_key_func,
rec_filter_func=rec_filter_func,
artifacts_key=artifacts_key,
)
return artifacts_collector
def first_tasks(self) -> List[dict]:
"""
Use rolling_gen to generate different tasks based on task_template.
Returns:
List[dict]: a list of tasks
"""
return task_generator(
tasks=self.task_template,
generators=self.rg, # generate different date segment
)
def prepare_tasks(self, cur_time) -> List[dict]:
"""
Prepare new tasks based on cur_time (None for the latest).
You can find the last online models by OnlineToolR.online_models.
Returns:
List[dict]: a list of new tasks.
"""
latest_records, max_test = self._list_latest(self.tool.online_models())
if max_test is None:
self.logger.warn(f"No latest online recorders, no new tasks.")
return []
calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
self.logger.info(
f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}"
)
if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step:
old_tasks = []
tasks_tmp = []
for rec in latest_records:
task = rec.load_object("task")
old_tasks.append(deepcopy(task))
test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
# modify the test segment to generate new tasks
task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
tasks_tmp.append(task)
new_tasks_tmp = task_generator(tasks_tmp, self.rg)
new_tasks = [task for task in new_tasks_tmp if task not in old_tasks]
return new_tasks
return []
def _list_latest(self, rec_list: List[Recorder]):
"""
List latest recorder form rec_list
Args:
rec_list (List[Recorder]): a list of Recorder
Returns:
List[Recorder], pd.Timestamp: the latest recorders and their test end time
"""
if len(rec_list) == 0:
return rec_list, None
max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list)
latest_rec = []
for rec in rec_list:
if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test:
latest_rec.append(rec)
return latest_rec, max_test

View File

@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Updater is a module to update artifacts such as predictions when the stock data is updating.
"""
from abc import ABCMeta, abstractmethod
import pandas as pd
from qlib import get_module_logger
from qlib.data import D
from qlib.data.dataset import DatasetH
from qlib.data.dataset.handler import DataHandlerLP
from qlib.model import Model
from qlib.utils import get_date_by_shift
from qlib.workflow.recorder import Recorder
class RMDLoader:
"""
Recorder Model Dataset Loader
"""
def __init__(self, rec: Recorder):
self.rec = rec
def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
"""
Load, config and setup dataset.
This dataset is for inference.
Args:
start_time :
the start_time of underlying data
end_time :
the end_time of underlying data
segments : dict
the segments config for dataset
Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time
Returns:
DatasetH: the instance of DatasetH
"""
if segments is None:
segments = {"test": (start_time, end_time)}
dataset: DatasetH = self.rec.load_object("dataset")
dataset.config(handler_kwargs={"start_time": start_time, "end_time": end_time}, segments=segments)
dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
return dataset
def get_model(self) -> Model:
return self.rec.load_object("params.pkl")
class RecordUpdater(metaclass=ABCMeta):
"""
Update a specific recorders
"""
def __init__(self, record: Recorder, *args, **kwargs):
self.record = record
self.logger = get_module_logger(self.__class__.__name__)
@abstractmethod
def update(self, *args, **kwargs):
"""
Update info for specific recorder
"""
...
class PredUpdater(RecordUpdater):
"""
Update the prediction in the Recorder
"""
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
"""
Init PredUpdater.
Args:
record : Recorder
to_date :
update to prediction to the `to_date`
hist_ref : int
Sometimes, the dataset will have historical depends.
Leave the problem to users to set the length of historical dependency
.. note::
the start_time is not included in the hist_ref
"""
# TODO: automate this hist_ref in the future.
super().__init__(record=record)
self.to_date = to_date
self.hist_ref = hist_ref
self.freq = freq
self.rmdl = RMDLoader(rec=record)
if to_date == None:
to_date = D.calendar(freq=freq)[-1]
self.to_date = pd.Timestamp(to_date)
self.old_pred = record.load_object("pred.pkl")
self.last_end = self.old_pred.index.get_level_values("datetime").max()
def prepare_data(self) -> DatasetH:
"""
Load dataset
Separating this function will make it easier to reuse the dataset
Returns:
DatasetH: the instance of DatasetH
"""
start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
seg = {"test": (start_time, self.to_date)}
dataset = self.rmdl.get_dataset(start_time=start_time_buffer, end_time=self.to_date, segments=seg)
return dataset
def update(self, dataset: DatasetH = None):
"""
Update the prediction in a recorder.
Args:
DatasetH: the instance of DatasetH. None for reprepare.
"""
# FIXME: the problem below is not solved
# The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
# RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
# https://github.com/pytorch/pytorch/issues/16797
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
if start_time >= self.to_date:
self.logger.info(
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
)
return
# load dataset
if dataset is None:
# For reusing the dataset
dataset = self.prepare_data()
# Load model
model = self.rmdl.get_model()
new_pred: pd.Series = model.predict(dataset)
cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
cb_pred = cb_pred.sort_index()
self.record.save_objects(**{"pred.pkl": cb_pred})
self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")

View File

@@ -0,0 +1,168 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
OnlineTool is a module to set and unset a series of `online` models.
The `online` models are some decisive models in some time points, which can be changed with the change of time.
This allows us to use efficient submodels as the market-style changing.
"""
from typing import List, Union
from qlib.log import get_module_logger
from qlib.workflow.online.update import PredUpdater
from qlib.workflow.recorder import Recorder
from qlib.workflow.task.utils import list_recorders
class OnlineTool:
"""
OnlineTool will manage `online` models in an experiment that includes the model recorders.
"""
ONLINE_KEY = "online_status" # the online status key in recorder
ONLINE_TAG = "online" # the 'online' model
OFFLINE_TAG = "offline" # the 'offline' model, not for online serving
def __init__(self):
"""
Init OnlineTool.
"""
self.logger = get_module_logger(self.__class__.__name__)
def set_online_tag(self, tag, recorder: Union[list, object]):
"""
Set `tag` to the model to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[list,object]): the model's recorder
"""
raise NotImplementedError(f"Please implement the `set_online_tag` method.")
def get_online_tag(self, recorder: object) -> str:
"""
Given a model recorder and return its online tag.
Args:
recorder (Object): the model's recorder
Returns:
str: the online tag
"""
raise NotImplementedError(f"Please implement the `get_online_tag` method.")
def reset_online_tag(self, recorder: Union[list, object]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[list,object]):
the recorder you want to reset to 'online'.
"""
raise NotImplementedError(f"Please implement the `reset_online_tag` method.")
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
raise NotImplementedError(f"Please implement the `online_models` method.")
def update_online_pred(self, to_date=None):
"""
Update the predictions of `online` models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest.
"""
raise NotImplementedError(f"Please implement the `update_online_pred` method.")
class OnlineToolR(OnlineTool):
"""
The implementation of OnlineTool based on (R)ecorder.
"""
def __init__(self, experiment_name: str):
"""
Init OnlineToolR.
Args:
experiment_name (str): the experiment name.
"""
super().__init__()
self.exp_name = experiment_name
def set_online_tag(self, tag, recorder: Union[Recorder, List]):
"""
Set `tag` to the model's recorder to sign whether online.
Args:
tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG`
recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
for rec in recorder:
rec.set_tags(**{self.ONLINE_KEY: tag})
self.logger.info(f"Set {len(recorder)} models to '{tag}'.")
def get_online_tag(self, recorder: Recorder) -> str:
"""
Given a model recorder and return its online tag.
Args:
recorder (Recorder): an instance of recorder
Returns:
str: the online tag
"""
tags = recorder.list_tags()
return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG)
def reset_online_tag(self, recorder: Union[Recorder, List]):
"""
Offline all models and set the recorders to 'online'.
Args:
recorder (Union[Recorder, List]):
the recorder you want to reset to 'online'.
"""
if isinstance(recorder, Recorder):
recorder = [recorder]
recs = list_recorders(self.exp_name)
self.set_online_tag(self.OFFLINE_TAG, list(recs.values()))
self.set_online_tag(self.ONLINE_TAG, recorder)
def online_models(self) -> list:
"""
Get current `online` models
Returns:
list: a list of `online` models.
"""
return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values())
def update_online_pred(self, to_date=None):
"""
Update the predictions of online models to to_date.
Args:
to_date (pd.Timestamp): the pred before this date will be updated. None for updating to latest time in Calendar.
"""
online_models = self.online_models()
for rec in online_models:
hist_ref = 0
task = rec.load_object("task")
# Special treatment of historical dependencies
if task["dataset"]["class"] == "TSDatasetH":
hist_ref = task["dataset"]["kwargs"]["step_len"]
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")

View File

@@ -1,7 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import re, logging
import pandas as pd
from pathlib import Path
from pprint import pprint
@@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
from ..contrib.strategy.strategy import BaseStrategy
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class RecordTemp:
@@ -39,7 +39,13 @@ class RecordTemp:
return "/".join(names)
def __init__(self, recorder):
self.recorder = recorder
self._recorder = recorder
@property
def recorder(self):
if self._recorder is None:
raise ValueError("This RecordTemp did not set recorder yet.")
return self._recorder
def generate(self, **kwargs):
"""
@@ -145,6 +151,10 @@ class SignalRecord(RecordTemp):
del params["data_key"]
# The backend handler should be DataHandler
raw_label = self.dataset.prepare(**params)
except AttributeError:
# The data handler is initialize with `drop_raw=True`...
# So raw_label is not available
raw_label = None
self.recorder.save_objects(**{"label.pkl": raw_label})
self.dataset.__class__ = orig_cls
@@ -156,6 +166,60 @@ class SignalRecord(RecordTemp):
return super().load(name)
class HFSignalRecord(SignalRecord):
"""
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
"""
artifact_path = "hg_sig_analysis"
def __init__(self, recorder, **kwargs):
super().__init__(recorder=recorder)
def generate(self):
pred = self.load("pred.pkl")
raw_label = self.load("label.pkl")
long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)
ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
"ICIR": ic.mean() / ic.std(),
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
"Long precision": long_pre.mean(),
"Short precision": short_pre.mean(),
}
objects = {"ic.pkl": ic, "ric.pkl": ric}
objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre})
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0])
metrics.update(
{
"Long-Short Average Return": long_short_r.mean(),
"Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
}
)
objects.update(
{
"long_short_r.pkl": long_short_r,
"long_avg_r.pkl": long_avg_r,
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
class SigAnaRecord(SignalRecord):
"""
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
@@ -176,6 +240,9 @@ class SigAnaRecord(SignalRecord):
pred = self.load("pred.pkl")
label = self.load("label.pkl")
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
logger.warn(f"Empty label.")
return
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
metrics = {
"IC": ic.mean(),
@@ -248,11 +315,20 @@ class PortAnaRecord(SignalRecord):
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
report_normal = report_dict.get("report_df")
positions_normal = report_dict.get("positions")
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{"report_normal.pkl": report_normal},
artifact_path=PortAnaRecord.get_path(),
)
self.recorder.save_objects(
**{"positions_normal.pkl": positions_normal},
artifact_path=PortAnaRecord.get_path(),
)
order_normal = report_dict.get("order_list")
if order_normal:
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
self.recorder.save_objects(
**{"order_normal.pkl": order_normal},
artifact_path=PortAnaRecord.get_path(),
)
# analysis
analysis = dict()

View File

@@ -1,14 +1,14 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import mlflow
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
from pathlib import Path
from datetime import datetime
from ..utils.objm import FileManager
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
class Recorder:
@@ -39,6 +39,9 @@ class Recorder:
def __str__(self):
return str(self.info)
def __hash__(self) -> int:
return hash(self.info["id"])
@property
def info(self):
output = dict()
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
client=self.client,
)
def __hash__(self) -> int:
return hash(self.info["id"])
def __eq__(self, o: object) -> bool:
if isinstance(o, MLflowRecorder):
return self.info["id"] == o.info["id"]
return False
@property
def uri(self):
return self._uri

View File

@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Task related workflow is implemented in this folder
A typical task workflow
| Step | Description |
|-----------------------+------------------------------------------------|
| TaskGen | Generating tasks. |
| TaskManager(optional) | Manage generated tasks |
| run task | retrive tasks from TaskManager and run tasks. |
"""

View File

@@ -0,0 +1,219 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
"""
from typing import Callable, Dict, List
from qlib.utils.serial import Serializable
from qlib.workflow import R
class Collector(Serializable):
"""The collector to collect different results"""
pickle_backend = "dill" # use dill to dump user method
def __init__(self, process_list=[]):
"""
Init Collector.
Args:
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
"""
if not isinstance(process_list, list):
process_list = [process_list]
self.process_list = process_list
def collect(self) -> dict:
"""
Collect the results and return a dict like {key: things}
Returns:
dict: the dict after collecting.
For example:
{"prediction": pd.Series}
{"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
......
"""
raise NotImplementedError(f"Please implement the `collect` method.")
@staticmethod
def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict:
"""
Do a series of processing to the dict returned by collect and return a dict like {key: things}
For example, you can group and ensemble.
Args:
collected_dict (dict): the dict return by `collect`
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
The processor order is the same as the list order.
For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
Returns:
dict: the dict after processing.
"""
if not isinstance(process_list, list):
process_list = [process_list]
result = {}
for artifact in collected_dict:
value = collected_dict[artifact]
for process in process_list:
if not callable(process):
raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
value = process(value, *args, **kwargs)
result[artifact] = value
return result
def __call__(self, *args, **kwargs) -> dict:
"""
Do the workflow including ``collect`` and ``process_collect``
Returns:
dict: the dict after collecting and processing.
"""
collected = self.collect()
return self.process_collect(collected, self.process_list, *args, **kwargs)
class MergeCollector(Collector):
"""
A collector to collect the results of other Collectors
For example:
We have 2 collector, which named A and B.
A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}
......
"""
def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = [], merge_func=None):
"""
Init MergeCollector.
Args:
collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
process_list (List[Callable]): the list of processors or the instance of processor to process dict.
merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
"""
super().__init__(process_list=process_list)
self.collector_dict = collector_dict
self.merge_func = merge_func
def collect(self) -> dict:
"""
Collect all results of collector_dict and change the outermost key to a recombination key.
Returns:
dict: the dict after collecting.
"""
collect_dict = {}
for collector_key, collector in self.collector_dict.items():
tmp_dict = collector()
for key, value in tmp_dict.items():
if self.merge_func is not None:
collect_dict[self.merge_func(collector_key, key)] = value
else:
collect_dict[(collector_key, key)] = value
return collect_dict
class RecorderCollector(Collector):
ART_KEY_RAW = "__raw"
def __init__(
self,
experiment,
process_list=[],
rec_key_func=None,
rec_filter_func=None,
artifacts_path={"pred": "pred.pkl"},
artifacts_key=None,
):
"""
Init RecorderCollector.
Args:
experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
process_list (list or Callable): the list of processors or the instance of a processor to process dict.
rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}.
artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
"""
super().__init__(process_list=process_list)
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
self.experiment = experiment
self.artifacts_path = artifacts_path
if rec_key_func is None:
rec_key_func = lambda rec: rec.info["id"]
if artifacts_key is None:
artifacts_key = list(self.artifacts_path.keys())
self.rec_key_func = rec_key_func
self.artifacts_key = artifacts_key
self.rec_filter_func = rec_filter_func
def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
"""
Collect different artifacts based on recorder after filtering.
Args:
artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
only_exist (bool, optional): if only collect the artifacts when a recorder really has.
If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.
Returns:
dict: the dict after collected like {artifact: {rec_key: object}}
"""
if artifacts_key is None:
artifacts_key = self.artifacts_key
if rec_filter_func is None:
rec_filter_func = self.rec_filter_func
if isinstance(artifacts_key, str):
artifacts_key = [artifacts_key]
collect_dict = {}
# filter records
recs = self.experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
for _, rec in recs_flt.items():
rec_key = self.rec_key_func(rec)
for key in artifacts_key:
if self.ART_KEY_RAW == key:
artifact = rec
else:
try:
artifact = rec.load_object(self.artifacts_path[key])
except Exception as e:
if only_exist:
# only collect existing artifact
continue
raise e
collect_dict.setdefault(key, {})[rec_key] = artifact
return collect_dict
def get_exp_name(self) -> str:
"""
Get experiment name
Returns:
str: experiment name
"""
return self.experiment.name

231
qlib/workflow/task/gen.py Normal file
View File

@@ -0,0 +1,231 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
"""
import abc
import copy
from typing import List, Union, Callable
from .utils import TimeAdjuster
def task_generator(tasks, generators) -> list:
"""
Use a list of TaskGen and a list of task templates to generate different tasks.
For examples:
There are 3 task templates a,b,c and 2 TaskGen A,B. A will generates 2 tasks from a template and B will generates 3 tasks from a template.
task_generator([a, b, c], [A, B]) will finally generate 3*2*3 = 18 tasks.
Parameters
----------
tasks : List[dict] or dict
a list of task templates or a single task
generators : List[TaskGen] or TaskGen
a list of TaskGen or a single TaskGen
Returns
-------
list
a list of tasks
"""
if isinstance(tasks, dict):
tasks = [tasks]
if isinstance(generators, TaskGen):
generators = [generators]
# generate gen_task_list
for gen in generators:
new_task_list = []
for task in tasks:
new_task_list.extend(gen.generate(task))
tasks = new_task_list
return tasks
class TaskGen(metaclass=abc.ABCMeta):
"""
The base class for generating different tasks
Example 1:
input: a specific task template and rolling steps
output: rolling version of the tasks
Example 2:
input: a specific task template and losses list
output: a set of tasks with different losses
"""
@abc.abstractmethod
def generate(self, task: dict) -> List[dict]:
"""
Generate different tasks based on a task template
Parameters
----------
task: dict
a task template
Returns
-------
typing.List[dict]:
A list of tasks
"""
pass
def __call__(self, *args, **kwargs):
"""
This is just a syntactic sugar for generate
"""
return self.generate(*args, **kwargs)
def handler_mod(task: dict, rolling_gen):
"""
Help to modify the handler end time when using RollingGen
Args:
task (dict): a task template
rg (RollingGen): an instance of RollingGen
"""
try:
interval = rolling_gen.ta.cal_interval(
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"],
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1],
)
# if end_time < the end of test_segments, then change end_time to allow load more data
if interval < 0:
task["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(
task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
)
except KeyError:
# Maybe dataset do not have handler, then do nothing.
pass
class RollingGen(TaskGen):
ROLL_EX = TimeAdjuster.SHIFT_EX # fixed start date, expanding end date
ROLL_SD = TimeAdjuster.SHIFT_SD # fixed segments size, slide it from start date
def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
"""
Generate tasks for rolling
Parameters
----------
step : int
step to rolling
rtype : str
rolling type (expanding, sliding)
ds_extra_mod_func: Callable
A method like: handler_mod(task: dict, rg: RollingGen)
Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
"""
self.step = step
self.rtype = rtype
self.ds_extra_mod_func = ds_extra_mod_func
self.ta = TimeAdjuster(future=True)
self.test_key = "test"
self.train_key = "train"
def generate(self, task: dict) -> List[dict]:
"""
Converting the task into a rolling task.
Parameters
----------
task: dict
A dict describing a task. For example.
.. code-block:: python
DEFAULT_TASK = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": "csi100",
},
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-20"), # Please avoid leaking the future test data into validation
"test": ("2017-01-01", "2020-08-01"),
},
},
},
"record": [
{
"class": "SignalRecord",
"module_path": "qlib.workflow.record_temp",
},
]
}
Returns
----------
List[dict]: a list of tasks
"""
res = []
prev_seg = None
test_end = None
while True:
t = copy.deepcopy(task)
# calculate segments
if prev_seg is None:
# First rolling
# 1) prepare the end point
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
# 2) and init test segments
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
else:
segments = {}
try:
for k, seg in prev_seg.items():
# decide how to shift
# expanding only for train data, the segments size of test data and valid data won't change
if k == self.train_key and self.rtype == self.ROLL_EX:
rtype = self.ta.SHIFT_EX
else:
rtype = self.ta.SHIFT_SD
# shift the segments data
segments[k] = self.ta.shift(seg, step=self.step, rtype=rtype)
if segments[self.test_key][0] > test_end:
break
except KeyError:
# We reach the end of tasks
# No more rolling
break
# update segments of this task
t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
prev_seg = segments
if self.ds_extra_mod_func is not None:
self.ds_extra_mod_func(t, self)
res.append(t)
return res

View File

@@ -0,0 +1,493 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
These features can run tasks concurrently and ensure every task will be used only once.
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
Users **MUST** finished the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
A task in TaskManager consists of 3 parts
- tasks description: the desc will define the task
- tasks status: the status of the task
- tasks result: A user can get the task with the task description and task result.
"""
import concurrent
import pickle
import time
from contextlib import contextmanager
from typing import Callable, List
import fire
import pymongo
from bson.binary import Binary
from bson.objectid import ObjectId
from pymongo.errors import InvalidDocument
from qlib import auto_init, get_module_logger
from tqdm.cli import tqdm
from .utils import get_mongodb
class TaskManager:
"""
TaskManager
Here is what will a task looks like when it created by TaskManager
.. code-block:: python
{
'def': pickle serialized task definition. using pickle will make it easier
'filter': json-like data. This is for filtering the tasks.
'status': 'waiting' | 'running' | 'done'
'res': pickle serialized task result,
}
The tasks manager assumes that you will only update the tasks you fetched.
The mongo fetch one and update will make it date updating secure.
.. note::
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
Here are four status which are:
STATUS_WAITING: waiting for training
STATUS_RUNNING: training
STATUS_PART_DONE: finished some step and waiting for next step
STATUS_DONE: all work done
"""
STATUS_WAITING = "waiting"
STATUS_RUNNING = "running"
STATUS_DONE = "done"
STATUS_PART_DONE = "part_done"
ENCODE_FIELDS_PREFIX = ["def", "res"]
def __init__(self, task_pool: str = None):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
self.mdb = get_mongodb()
if task_pool is not None:
self.task_pool = getattr(self.mdb, task_pool)
self.logger = get_module_logger(self.__class__.__name__)
def list(self) -> list:
"""
List the all collection(task_pool) of the db
Returns:
list
"""
return self.mdb.list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = Binary(pickle.dumps(task[k]))
return task
def _decode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
for k in list(task.keys()):
if k.startswith(prefix):
task[k] = pickle.loads(task[k])
return task
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
Args:
task: old task
new_task: new task
"""
new_task = self._encode_task(new_task)
query = {"_id": ObjectId(task["_id"])}
try:
self.task_pool.replace_one(query, new_task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
self.task_pool.replace_one(query, new_task)
def insert_task(self, task):
"""
Insert a task.
Args:
task: the task waiting for insert
Returns:
pymongo.results.InsertOneResult
"""
try:
insert_result = self.task_pool.insert_one(task)
except InvalidDocument:
task["filter"] = self._dict_to_str(task["filter"])
insert_result = self.task_pool.insert_one(task)
return insert_result
def insert_task_def(self, task_def):
"""
Insert a task to task_pool
Parameters
----------
task_def: dict
the task definition
Returns
-------
pymongo.results.InsertOneResult
"""
task = self._encode_task(
{
"def": task_def,
"filter": task_def, # FIXME: catch the raised error
"status": self.STATUS_WAITING,
}
)
insert_result = self.insert_task(task)
return insert_result
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
"""
If the tasks in task_def_l are new, then insert new tasks into the task_pool, and record inserted_id.
If a task is not new, then just query its _id.
Parameters
----------
task_def_l: list
a list of task
dry_run: bool
if insert those new tasks to task pool
print_nt: bool
if print new task
Returns
-------
List[str]
a list of the _id of task_def_l
"""
new_tasks = []
_id_list = []
for t in task_def_l:
try:
r = self.task_pool.find_one({"filter": t})
except InvalidDocument:
r = self.task_pool.find_one({"filter": self._dict_to_str(t)})
if r is None:
new_tasks.append(t)
if not dry_run:
insert_result = self.insert_task_def(t)
_id_list.append(insert_result.inserted_id)
else:
_id_list.append(None)
else:
_id_list.append(self._decode_task(r)["_id"])
self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(new_tasks)}")
if print_nt: # print new task
for t in new_tasks:
print(t)
if dry_run:
return []
return _id_list
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
"""
Use query to fetch tasks.
Args:
query (dict, optional): query dict. Defaults to {}.
status (str, optional): [description]. Defaults to STATUS_WAITING.
Returns:
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
)
# null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority
if task is None:
return None
task["status"] = self.STATUS_RUNNING
return self._decode_task(task)
@contextmanager
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
"""
Fetch task from task_pool using query with contextmanager
Parameters
----------
query: dict
the dict of query
Returns
-------
dict: a task(document in collection) after decoding
"""
task = self.fetch_task(query=query, status=status)
try:
yield task
except Exception:
if task is not None:
self.logger.info("Returning task before raising error")
self.return_task(task)
self.logger.info("Task returned")
raise
def task_fetcher_iter(self, query={}):
while True:
with self.safe_fetch_task(query=query) as task:
if task is None:
break
yield task
def query(self, query={}, decode=True):
"""
Query task in collection.
This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator
Parameters
----------
query: dict
the dict of query
decode: bool
Returns
-------
dict: a task(document in collection) after decoding
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
for t in self.task_pool.find(query):
yield self._decode_task(t)
def re_query(self, _id):
"""
Use _id to query task.
Args:
_id (str): _id of a document
Returns:
dict: a task(document in collection) after decoding
"""
t = self.task_pool.find_one({"_id": ObjectId(_id)})
return self._decode_task(t)
def commit_task_res(self, task, res, status=STATUS_DONE):
"""
Commit the result to task['res'].
Args:
task ([type]): [description]
res (object): the result you want to save
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
"""
# A workaround to use the class attribute.
if status is None:
status = TaskManager.STATUS_DONE
self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}})
def return_task(self, task, status=STATUS_WAITING):
"""
Return a task to status. Alway using in error handling.
Args:
task ([type]): [description]
status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
"""
if status is None:
status = TaskManager.STATUS_WAITING
update_dict = {"$set": {"status": status}}
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def remove(self, query={}):
"""
Remove the task using query
Parameters
----------
query: dict
the dict of query
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
"""
Count the tasks in every status.
Args:
query (dict, optional): the query dict. Defaults to {}.
Returns:
dict
"""
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
status_stat[t["status"]] = status_stat.get(t["status"], 0) + 1
return status_stat
def reset_waiting(self, query={}):
"""
Reset all running task into waiting status. Can be used when some running task exit unexpected.
Args:
query (dict, optional): the query dict. Defaults to {}.
"""
query = query.copy()
# default query
if "status" not in query:
query["status"] = self.STATUS_RUNNING
return self.reset_status(query=query, status=self.STATUS_WAITING)
def reset_status(self, query, status):
query = query.copy()
if "_id" in query:
query["_id"] = ObjectId(query["_id"])
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
"""
Set priority for task
Parameters
----------
task : dict
The task query from the database
priority : int
the target priority
"""
update_dict = {"$set": {"priority": priority}}
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
def _get_undone_n(self, task_stat):
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
def _get_total(self, task_stat):
return sum(task_stat.values())
def wait(self, query={}):
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
undone_n = self._get_undone_n(self.task_stat(query))
pbar.update(last_undone_n - undone_n)
last_undone_n = undone_n
if undone_n == 0:
break
def __str__(self):
return f"TaskManager({self.task_pool})"
def run_task(
task_func: Callable,
task_pool: str,
query: dict = {},
force_release: bool = False,
before_status: str = TaskManager.STATUS_WAITING,
after_status: str = TaskManager.STATUS_DONE,
**kwargs,
):
"""
While the task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool
After running this method, here are 4 situations (before_status -> after_status):
STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param
STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param
STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param
STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param
Parameters
----------
task_func : Callable
def (task_def, **kwargs) -> <res which will be committed>
the function to run the task
task_pool : str
the name of the task pool (Collection in MongoDB)
query: dict
will use this dict to query task_pool when fetching task
force_release : bool
will the program force to release the resource
before_status : str:
the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
after_status : str:
the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE.
kwargs
the params for `task_func`
"""
tm = TaskManager(task_pool)
ever_run = False
while True:
with tm.safe_fetch_task(status=before_status, query=query) as task:
if task is None:
break
get_module_logger("run_task").info(task["def"])
# when fetching `WAITING` task, use task["def"] to train
if before_status == TaskManager.STATUS_WAITING:
param = task["def"]
# when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
elif before_status == TaskManager.STATUS_PART_DONE:
param = task["res"]
else:
raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")
if force_release:
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
res = executor.submit(task_func, param, **kwargs).result()
else:
res = task_func(param, **kwargs)
tm.commit_task_res(task, res, status=after_status)
ever_run = True
return ever_run
if __name__ == "__main__":
# This is for using it in cmd
# E.g. : `python -m qlib.workflow.task.manage list`
auto_init()
fire.Fire(TaskManager)

258
qlib/workflow/task/utils.py Normal file
View File

@@ -0,0 +1,258 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Some tools for task management.
"""
import bisect
import pandas as pd
from qlib.data import D
from qlib.workflow import R
from qlib.config import C
from qlib.log import get_module_logger
from pymongo import MongoClient
from pymongo.database import Database
from typing import Union
def get_mongodb() -> Database:
"""
Get database in MongoDB, which means you need to declare the address and the name of a database at first.
For example:
Using qlib.init():
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
}
qlib.init(..., mongo=mongo_conf)
After qlib.init():
C["mongo"] = {
"task_url" : "mongodb://localhost:27017/",
"task_db_name" : "rolling_db"
}
Returns:
Database: the Database instance
"""
try:
cfg = C["mongo"]
except KeyError:
get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
raise
client = MongoClient(cfg["task_url"])
return client.get_database(name=cfg["task_db_name"])
def list_recorders(experiment, rec_filter_func=None):
"""
List all recorders which can pass the filter in an experiment.
Args:
experiment (str or Experiment): the name of an Experiment or an instance
rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None.
Returns:
dict: a dict {rid: recorder} after filtering.
"""
if isinstance(experiment, str):
experiment = R.get_exp(experiment_name=experiment)
recs = experiment.list_recorders()
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):
recs_flt[rid] = rec
return recs_flt
class TimeAdjuster:
"""
Find appropriate date and adjust date.
"""
def __init__(self, future=True, end_time=None):
self._future = future
self.cals = D.calendar(future=future, end_time=end_time)
def set_end_time(self, end_time=None):
"""
Set end time. None for use calendar's end time.
Args:
end_time
"""
self.cals = D.calendar(future=self._future, end_time=end_time)
def get(self, idx: int):
"""
Get datetime by index.
Parameters
----------
idx : int
index of the calendar
"""
if idx >= len(self.cals):
return None
return self.cals[idx]
def max(self) -> pd.Timestamp:
"""
Return the max calendar datetime
"""
return max(self.cals)
def align_idx(self, time_point, tp_type="start") -> int:
"""
Align the index of time_point in the calendar.
Parameters
----------
time_point
tp_type : str
Returns
-------
index : int
"""
time_point = pd.Timestamp(time_point)
if tp_type == "start":
idx = bisect.bisect_left(self.cals, time_point)
elif tp_type == "end":
idx = bisect.bisect_right(self.cals, time_point) - 1
else:
raise NotImplementedError(f"This type of input is not supported")
return idx
def cal_interval(self, time_point_A, time_point_B) -> int:
"""
Calculate the trading day interval (time_point_A - time_point_B)
Args:
time_point_A : time_point_A
time_point_B : time_point_B (is the past of time_point_A)
Returns:
int: the interval between A and B
"""
return self.align_idx(time_point_A) - self.align_idx(time_point_B)
def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
"""
Align time_point to trade date of calendar
Args:
time_point
Time point
tp_type : str
time point type (`"start"`, `"end"`)
Returns:
pd.Timestamp
"""
return self.cals[self.align_idx(time_point, tp_type=tp_type)]
def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
"""
Align the given date to the trade date
for example:
.. code-block:: python
input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}
output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}
Parameters
----------
segment
Returns
-------
Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.
"""
if isinstance(segment, dict):
return {k: self.align_seg(seg) for k, seg in segment.items()}
elif isinstance(segment, tuple) or isinstance(segment, list):
return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
else:
raise NotImplementedError(f"This type of input is not supported")
def truncate(self, segment: tuple, test_start, days: int) -> tuple:
"""
Truncate the segment based on the test_start date
Parameters
----------
segment : tuple
time segment
test_start
days : int
The trading days to be truncated
the data in this segment may need 'days' data
Returns
---------
tuple: new segment
"""
test_idx = self.align_idx(test_start)
if isinstance(segment, tuple):
new_seg = []
for time_point in segment:
tp_idx = min(self.align_idx(time_point), test_idx - days)
assert tp_idx > 0
new_seg.append(self.get(tp_idx))
return tuple(new_seg)
else:
raise NotImplementedError(f"This type of input is not supported")
SHIFT_SD = "sliding"
SHIFT_EX = "expanding"
def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
"""
Shift the datatime of segment
Parameters
----------
seg :
datetime segment
step : int
rolling step
rtype : str
rolling type ("sliding" or "expanding")
Returns
--------
tuple: new segment
Raises
------
KeyError:
shift will raise error if the index(both start and end) is out of self.cal
"""
if isinstance(seg, tuple):
start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
if rtype == self.SHIFT_SD:
start_idx += step
end_idx += step
elif rtype == self.SHIFT_EX:
end_idx += step
else:
raise NotImplementedError(f"This type of input is not supported")
if start_idx > len(self.cals):
raise KeyError("The segment is out of valid calendar")
return self.get(start_idx), self.get(end_idx)
else:
raise NotImplementedError(f"This type of input is not supported")

View File

@@ -1,12 +1,12 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys, traceback, signal, atexit
import sys, traceback, signal, atexit, logging
from . import R
from .recorder import Recorder
from ..log import get_module_logger
logger = get_module_logger("workflow", "INFO")
logger = get_module_logger("workflow", logging.INFO)
# function to handle the experiment when unusual program ending occurs

View File

@@ -15,7 +15,11 @@
### Download CN Data
```bash
# daily data
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
# 1min data (Optional for running non-high-frequency strategies)
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
```
### Downlaod US Data

View File

@@ -0,0 +1,24 @@
# Get future trading days
> `D.calendar(future=True)` will be used
## Requirements
```bash
pip install -r requirements.txt
```
## Collector Data
```bash
# parse instruments, using in qlib/instruments.
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
```
## Parameters
- qlib_dir: qlib data directory
- freq: value from [`day`, `1min`], default `day`

View File

@@ -0,0 +1,87 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from typing import List
from pathlib import Path
import fire
import numpy as np
import pandas as pd
from loguru import logger
# get data from baostock
import baostock as bs
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.utils import generate_minutes_calendar_from_daily
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt")
if not calendar_path.exists():
return pd.DataFrame()
return pd.read_csv(calendar_path, header=None)
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt"))
np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
logger.info(f"write future calendars success: {calendar_path}")
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
print(freq)
if freq == "day":
return date_list
elif freq == "1min":
date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list))
else:
raise ValueError(f"Unsupported freq: {freq}")
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
"""get future calendar
Parameters
----------
qlib_dir: str or Path
qlib data directory
freq: str
value from ["day", "1min"], by default day
"""
qlib_dir = Path(qlib_dir).expanduser().resolve()
if not qlib_dir.exists():
raise FileNotFoundError(str(qlib_dir))
lg = bs.login()
if lg.error_code != "0":
logger.error(f"login error: {lg.error_msg}")
return
# read daily calendar
daily_calendar = read_calendar_from_qlib(qlib_dir)
end_year = pd.Timestamp.now().year
if daily_calendar.empty:
start_year = pd.Timestamp.now().year
else:
start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
data_list = []
while (rs.error_code == "0") & rs.next():
_row_data = rs.get_row_data()
if int(_row_data[1]) == 1:
data_list.append(_row_data[0])
data_list = sorted(data_list)
date_list = generate_qlib_calendar(data_list, freq=freq)
write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
bs.logout()
logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
if __name__ == "__main__":
fire.Fire(future_calendar_collector)

View File

@@ -0,0 +1,5 @@
baostock
fire
numpy
pandas
loguru

View File

@@ -114,6 +114,8 @@ class IndexBase:
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
"""
df = self.get_new_companies()
if df is None or df.empty:
raise ValueError(f"get new companies error: {self.index_name}")
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
@@ -184,7 +186,10 @@ class IndexBase:
logger.info(f"start parse {self.index_name.lower()} companies.....")
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
changers_df = self.get_changes()
new_df = self.get_new_companies().copy()
new_df = self.get_new_companies()
if new_df is None or new_df.empty:
raise ValueError(f"get new companies error: {self.index_name}")
new_df = new_df.copy()
logger.info("parse history companies by changes......")
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
if _row.type == self.ADD:

View File

@@ -35,7 +35,7 @@ WIKI_INDEX_NAME_MAP = {
class WIKIIndex(IndexBase):
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
INST_PREFIX = "_"
INST_PREFIX = ""
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
super(WIKIIndex, self).__init__(
@@ -123,7 +123,7 @@ class NASDAQ100Index(WIKIIndex):
MAX_WORKERS = 16
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
if not (set(df.columns) - {"Company", "Ticker"}):
if len(df) >= 100 and "Ticker" in df.columns:
return df.loc[:, ["Ticker"]].copy()
@property

View File

@@ -10,7 +10,9 @@ import random
import requests
import functools
from pathlib import Path
from typing import Iterable, Tuple
import numpy as np
import pandas as pd
from lxml import etree
from loguru import logger
@@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
return res
def generate_minutes_calendar_from_daily(
calendars: Iterable,
freq: str = "1min",
am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
) -> pd.Index:
"""generate minutes calendar
Parameters
----------
calendars: Iterable
daily calendar
freq: str
by default 1min
am_range: Tuple[str, str]
AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
pm_range: Tuple[str, str]
PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")
"""
daily_format: str = "%Y-%m-%d"
res = []
for _day in calendars:
for _range in [am_range, pm_range]:
res.append(
pd.date_range(
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}",
f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}",
freq=freq,
)
)
return pd.Index(sorted(set(np.hstack(res))))
if __name__ == "__main__":
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM

View File

@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
CUR_DIR = Path(__file__).resolve().parent
sys.path.append(str(CUR_DIR.parent.parent))
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
from data_collector.utils import (
get_calendar_list,
get_hs_stock_symbols,
get_us_stock_symbols,
generate_minutes_calendar_from_daily,
)
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
return calendar_list_1d
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
res = []
daily_format = self.DAILY_FORMAT
am_range = self.AM_RANGE
pm_range = self.PM_RANGE
for _day in calendars:
for _range in [am_range, pm_range]:
res.append(
pd.date_range(
f"{_day.strftime(daily_format)} {_range[0]}",
f"{_day.strftime(daily_format)} {_range[1]}",
freq="1min",
)
)
return pd.Index(sorted(set(np.hstack(res))))
return generate_minutes_calendar_from_daily(
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
)
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
# TODO: using daily data factor

Some files were not shown because too many files have changed in this diff Show More