mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Merge remote-tracking branch 'microsoft/main' into data_storage
This commit is contained in:
62
.github/stale.yml
vendored
62
.github/stale.yml
vendored
@@ -1,62 +0,0 @@
|
||||
# Configuration for probot-stale - https://github.com/probot/stale
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request becomes stale
|
||||
daysUntilStale: 60
|
||||
|
||||
# Number of days of inactivity before an Issue or Pull Request with the stale label is closed.
|
||||
# Set to false to disable. If disabled, issues still need to be closed manually, but will remain marked as stale.
|
||||
daysUntilClose: 7
|
||||
|
||||
# Only issues or pull requests with all of these labels are check if stale. Defaults to `[]` (disabled)
|
||||
onlyLabels: []
|
||||
|
||||
# Issues or Pull Requests with these labels will never be considered stale. Set to `[]` to disable
|
||||
exemptLabels:
|
||||
- bug
|
||||
- pinned
|
||||
- security
|
||||
- "[Status] Maybe Later"
|
||||
|
||||
# Set to true to ignore issues in a project (defaults to false)
|
||||
exemptProjects: false
|
||||
|
||||
# Set to true to ignore issues in a milestone (defaults to false)
|
||||
exemptMilestones: false
|
||||
|
||||
# Set to true to ignore issues with an assignee (defaults to false)
|
||||
exemptAssignees: false
|
||||
|
||||
# Label to use when marking as stale
|
||||
staleLabel: wontfix
|
||||
|
||||
# Comment to post when marking as stale. Set to `false` to disable
|
||||
markComment: >
|
||||
This issue has been automatically marked as stale because it has not had
|
||||
recent activity. It will be closed if no further activity occurs. Thank you
|
||||
for your contributions.
|
||||
|
||||
# Comment to post when removing the stale label.
|
||||
# unmarkComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Comment to post when closing a stale Issue or Pull Request.
|
||||
# closeComment: >
|
||||
# Your comment here.
|
||||
|
||||
# Limit the number of actions per hour, from 1-30. Default is 30
|
||||
limitPerRun: 30
|
||||
|
||||
# Limit to only `issues` or `pulls`
|
||||
# only: issues
|
||||
|
||||
# Optionally, specify configuration settings that are specific to just 'issues' or 'pulls':
|
||||
# pulls:
|
||||
# daysUntilStale: 30
|
||||
# markComment: >
|
||||
# This pull request has been automatically marked as stale because it has not had
|
||||
# recent activity. It will be closed if no further activity occurs. Thank you
|
||||
# for your contributions.
|
||||
|
||||
# issues:
|
||||
# exemptLabels:
|
||||
# - confirmed
|
||||
24
.github/workflows/stale.yml
vendored
Normal file
24
.github/workflows/stale.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Mark stale issues and pull requests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 0/3 * * *"
|
||||
|
||||
jobs:
|
||||
stale:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- uses: actions/stale@v3
|
||||
with:
|
||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
stale-issue-message: 'This issue is stale because it has been open for three months with no activity. Remove the stale label or comment on the issue otherwise this will be closed in 5 days'
|
||||
stale-pr-message: 'This PR is stale because it has been open for a year with no activity. Remove the stale label or comment on the PR otherwise this will be closed in 5 days'
|
||||
stale-issue-label: 'stale'
|
||||
stale-pr-label: 'stale'
|
||||
days-before-stale: 90
|
||||
days-before-close: 5
|
||||
operations-per-run: 100
|
||||
exempt-issue-labels: 'bug,enhancement'
|
||||
remove-stale-when-updated: true
|
||||
@@ -45,16 +45,16 @@ New features under development(order by estimated release time).
|
||||
Your feedback about the features is very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Online serving and automatic model rolling | Under review: https://github.com/microsoft/qlib/pull/290 |
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Initial opensource version under development |
|
||||
| High-frequency trading | Under review: https://github.com/microsoft/qlib/pull/408 |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Online serving and automatic model rolling | Released: https://github.com/microsoft/qlib/pull/290 |
|
||||
| DoubleEnsemble Model | Released https://github.com/microsoft/qlib/pull/286 |
|
||||
| High-frequency data processing example | Released https://github.com/microsoft/qlib/pull/257 |
|
||||
| High-frequency trading example | Part of code released https://github.com/microsoft/qlib/pull/227 |
|
||||
@@ -243,6 +243,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
- Rank Label
|
||||

|
||||
-->
|
||||
- [Explanation](https://qlib.readthedocs.io/en/latest/component/report.html) of above results
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
BIN
docs/_static/img/online_serving.png
vendored
Normal file
BIN
docs/_static/img/online_serving.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 440 KiB |
@@ -14,6 +14,9 @@ Serializable Class
|
||||
|
||||
``Qlib`` provides a base class ``qlib.utils.serial.Serializable``, whose state can be dumped into or loaded from disk in `pickle` format.
|
||||
When users dump the state of a ``Serializable`` instance, the attributes of the instance whose name **does not** start with `_` will be saved on the disk.
|
||||
However, users can use ``config`` method or override ``default_dump_all`` attribute to prevent this feature.
|
||||
|
||||
Users can also override the ``pickle_backend`` attribute to choose a pickle backend. The supported values are "pickle" (default and common) and "dill" (which can dump more things, such as functions; more information `here <https://pypi.org/project/dill/>`_).
|
||||
|
||||
Example
|
||||
==========================
|
||||
|
||||
89
docs/advanced/task_management.rst
Normal file
89
docs/advanced/task_management.rst
Normal file
@@ -0,0 +1,89 @@
|
||||
.. _task_management:
|
||||
|
||||
=================================
|
||||
Task Management
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``.
|
||||
To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_.
|
||||
With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models.
|
||||
|
||||
This whole process can be used in `Online Serving <../component/online.html>`_.
|
||||
|
||||
An example of the entire process is shown `here <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
|
||||
Task Generating
|
||||
===============
|
||||
A ``task`` consists of `Model`, `Dataset`, `Record`, or anything added by users.
|
||||
The specific task template can be viewed in
|
||||
`Task Section <../component/workflow.html#task-section>`_.
|
||||
Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template.
|
||||
|
||||
Here is the base class of ``TaskGen``:
|
||||
|
||||
.. autoclass:: qlib.workflow.task.gen.TaskGen
|
||||
:members:
|
||||
|
||||
``Qlib`` provides a class `RollingGen <https://github.com/microsoft/qlib/tree/main/qlib/workflow/task/gen.py>`_ to generate a list of ``task`` of the dataset in different date segments.
|
||||
This class allows users to verify the effect of data from different periods on the model in one experiment. More information is `here <../reference/api.html#TaskGen>`_.
|
||||
|
||||
Task Storing
|
||||
===============
|
||||
To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ when using this module.
|
||||
|
||||
Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make a statement like this.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from qlib.config import C
|
||||
C["mongo"] = {
|
||||
"task_url" : "mongodb://localhost:27017/", # your MongoDB url
|
||||
"task_db_name" : "rolling_db" # database name
|
||||
}
|
||||
|
||||
.. autoclass:: qlib.workflow.task.manage.TaskManager
|
||||
:members:
|
||||
|
||||
More information about ``Task Manager`` can be found `here <../reference/api.html#TaskManager>`_.
|
||||
|
||||
Task Training
|
||||
===============
|
||||
After generating and storing those ``task``, it's time to run the ``task`` which is in the *WAITING* status.
|
||||
``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed.
|
||||
An easy way to get the ``task_func`` is using ``qlib.model.trainer.task_train`` directly.
|
||||
It will run the whole workflow defined by ``task``, which includes *Model*, *Dataset*, *Record*.
|
||||
|
||||
.. autofunction:: qlib.workflow.task.manage.run_task
|
||||
|
||||
Meanwhile, ``Qlib`` provides a module called ``Trainer``.
|
||||
|
||||
.. autoclass:: qlib.model.trainer.Trainer
|
||||
:members:
|
||||
|
||||
``Trainer`` will train a list of tasks and return a list of model recorders.
|
||||
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on ``TaskManager`` to help manage the lifecycle of tasks automatically.
|
||||
If you do not want to use ``Task Manager`` to manage tasks, then using ``TrainerR`` to train a list of tasks generated by ``TaskGen`` is enough.
|
||||
`Here <../reference/api.html#Trainer>`_ are the details about different ``Trainer``.
|
||||
|
||||
Task Collecting
|
||||
===============
|
||||
To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way.
|
||||
|
||||
`Collector <../reference/api.html#Collector>`_ can collect objects from everywhere and process them, for example by merging, grouping or averaging. It works in two steps: ``collect`` (collect anything into a dict) and ``process_collect`` (process the collected dict).
|
||||
|
||||
`Group <../reference/api.html#Group>`_ also has two steps: ``group`` (group a set of objects based on `group_func` and turn them into a dict) and ``reduce`` (turn a dict into an ensemble based on some rule).
|
||||
For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object}
|
||||
|
||||
`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble.
|
||||
For example: {C1: object, C2: object} ---``Ensemble``---> object
|
||||
|
||||
So the hierarchy is as follows: the second step of ``Collector`` corresponds to ``Group``, and the second step of ``Group`` corresponds to ``Ensemble``.
|
||||
|
||||
For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example <https://github.com/microsoft/qlib/tree/main/examples/model_rolling/task_manager_rolling.py>`_.
|
||||
@@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US)
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
PRs for new data sources are highly welcome! Users can contribute the code for crawling data as a PR, like `the examples here <https://github.com/microsoft/qlib/tree/main/scripts>`_. We will then use the code to create a data cache on our server, which other users can use directly.
|
||||
|
||||
|
||||
Data API
|
||||
========================
|
||||
|
||||
@@ -298,9 +303,10 @@ Here are some important interfaces that ``DataHandlerLP`` provides:
|
||||
.. autoclass:: qlib.data.dataset.handler.DataHandlerLP
|
||||
:members: __init__, fetch, get_cols
|
||||
|
||||
If users want to load features and labels by config, users can inherit ``qlib.data.dataset.handler.ConfigDataHandler``, ``Qlib`` also provides some preprocess method in this subclass.
|
||||
|
||||
If users want to use qlib data, `QLibDataHandler` is recommended. Users can inherit their custom class from `QLibDataHandler`, which is also a subclass of `ConfigDataHandler`.
|
||||
If users want to load features and labels by config, users can define a new handler and call the static method `parse_config_to_fields` of ``qlib.contrib.data.handler.Alpha158``.
|
||||
|
||||
Also, users can pass ``qlib.contrib.data.processor.ConfigSectionProcessor`` that provides some preprocess methods for features defined by config into the new handler.
|
||||
|
||||
|
||||
Processor
|
||||
@@ -337,7 +343,6 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
|
||||
.. note:: Users need to initialize ``Qlib`` with `qlib.init` first, please refer to `initialization <../start/initialization.html>`_.
|
||||
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
import qlib
|
||||
@@ -364,6 +369,9 @@ Qlib provides implemented data handler `Alpha158`. The following example shows h
|
||||
# fetch all the features
|
||||
print(h.fetch(col_set="feature"))
|
||||
|
||||
|
||||
.. note:: In ``Alpha158``, ``Qlib`` uses the label `Ref($close, -2)/Ref($close, -1) - 1`, which means the change from T+1 to T+2, rather than `Ref($close, -1)/$close - 1`. The reason is that after obtaining the close price of a Chinese stock on day T, the stock can be bought on day T+1 and sold on day T+2.
|
||||
|
||||
API
|
||||
---------
|
||||
|
||||
@@ -388,8 +396,7 @@ The ``DatasetH`` class is the `dataset` with `Data Handler`. Here is the most im
|
||||
API
|
||||
---------
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#module-qlib.data.dataset.__init__>`_.
|
||||
|
||||
To know more about ``Dataset``, please refer to `Dataset API <../reference/api.html#dataset>`_.
|
||||
|
||||
|
||||
Cache
|
||||
|
||||
46
docs/component/online.rst
Normal file
46
docs/component/online.rst
Normal file
@@ -0,0 +1,46 @@
|
||||
.. _online:
|
||||
|
||||
=================================
|
||||
Online Serving
|
||||
=================================
|
||||
.. currentmodule:: qlib
|
||||
|
||||
|
||||
Introduction
|
||||
=============
|
||||
|
||||
.. image:: ../_static/img/online_serving.png
|
||||
:align: center
|
||||
|
||||
|
||||
In addition to backtesting, one way to test whether a model is effective is to make predictions in real market conditions or even do real trading based on those predictions.
|
||||
``Online Serving`` is a set of modules for online models using the latest data,
|
||||
which includes `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_.
|
||||
|
||||
`Here <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are several examples for reference, which demonstrate different features of ``Online Serving``.
|
||||
If you have many models or `task` objects that need to be managed, please consider `Task Management <../advanced/task_management.html>`_.
|
||||
The `examples <https://github.com/microsoft/qlib/tree/main/examples/online_srv>`_ are based on some components in `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``.
|
||||
|
||||
Online Manager
|
||||
==============
|
||||
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
===============
|
||||
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
Updater
|
||||
=============
|
||||
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
@@ -34,6 +34,7 @@ Here is a general view of the structure of the system:
|
||||
- Recorder 2
|
||||
- ...
|
||||
- ...
|
||||
|
||||
This experiment management system defines a set of interfaces and provides a concrete implementation, ``MLflowExpManager``, which is based on the machine learning platform ``MLFlow`` (`link <https://mlflow.org/>`_).
|
||||
|
||||
If users set the implementation of ``ExpManager`` to be ``MLflowExpManager``, they can use the command `mlflow ui` to visualize and check the experiment results. For more information, please refer to the related documents `here <https://www.mlflow.org/docs/latest/cli.html#mlflow-ui>`_.
|
||||
|
||||
@@ -42,6 +42,7 @@ Document Structure
|
||||
Intraday Trading: Model&Strategy Testing <component/backtest.rst>
|
||||
Qlib Recorder: Experiment Management <component/recorder.rst>
|
||||
Analysis: Evaluation & Results Analysis <component/report.rst>
|
||||
Online Serving: Online Management & Strategy & Tool <component/online.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
@@ -50,6 +51,7 @@ Document Structure
|
||||
Building Formulaic Alphas <advanced/alpha.rst>
|
||||
Online & Offline mode <advanced/server.rst>
|
||||
Serialization <advanced/serial.rst>
|
||||
Task Management <advanced/task_management.rst>
|
||||
|
||||
.. toctree::
|
||||
:maxdepth: 3
|
||||
|
||||
@@ -154,6 +154,70 @@ Record Template
|
||||
.. automodule:: qlib.workflow.record_temp
|
||||
:members:
|
||||
|
||||
Task Management
|
||||
====================
|
||||
|
||||
|
||||
TaskGen
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.gen
|
||||
:members:
|
||||
|
||||
TaskManager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.manage
|
||||
:members:
|
||||
|
||||
Trainer
|
||||
--------------------
|
||||
.. automodule:: qlib.model.trainer
|
||||
:members:
|
||||
|
||||
Collector
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.collect
|
||||
:members:
|
||||
|
||||
Group
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.group
|
||||
:members:
|
||||
|
||||
Ensemble
|
||||
--------------------
|
||||
.. automodule:: qlib.model.ens.ensemble
|
||||
:members:
|
||||
|
||||
Utils
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.task.utils
|
||||
:members:
|
||||
|
||||
|
||||
Online Serving
|
||||
====================
|
||||
|
||||
|
||||
Online Manager
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.manager
|
||||
:members:
|
||||
|
||||
Online Strategy
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.strategy
|
||||
:members:
|
||||
|
||||
Online Tool
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.utils
|
||||
:members:
|
||||
|
||||
RecordUpdater
|
||||
--------------------
|
||||
.. automodule:: qlib.workflow.online.update
|
||||
:members:
|
||||
|
||||
|
||||
Utils
|
||||
====================
|
||||
@@ -162,4 +226,7 @@ Serializable
|
||||
--------------------
|
||||
|
||||
.. automodule:: qlib.utils.serial.Serializable
|
||||
:members:
|
||||
:members:
|
||||
|
||||
|
||||
|
||||
@@ -75,3 +75,14 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"default_exp_name": "Experiment",
|
||||
}
|
||||
})
|
||||
- `mongo`
|
||||
Type: dict, optional parameter, the setting of `MongoDB <https://www.mongodb.com/>`_ which will be used in some features such as `Task Management <../advanced/task_management.html>`_, with high performance and clustered processing.
|
||||
Users need to finish the `installation <https://www.mongodb.com/try/download/community>`_ first and run it at a fixed URL.
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# For example, you can initialize qlib below
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN, mongo={
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
instruments: *market
|
||||
data_loader:
|
||||
class: QlibDataLoader
|
||||
kwargs:
|
||||
config:
|
||||
feature:
|
||||
- ["Resi($close, 15)/$close", "Std(Abs($close/Ref($close, 1)-1)*$volume, 5)/(Mean(Abs($close/Ref($close, 1)-1)*$volume, 5)+1e-12)", "Rsquare($close, 5)", "($high-$low)/$open", "Rsquare($close, 10)", "Corr($close, Log($volume+1), 5)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 5)", "Corr($close, Log($volume+1), 10)", "Rsquare($close, 20)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 60)", "Corr($close/Ref($close,1), Log($volume/Ref($volume, 1)+1), 10)", "Corr($close, Log($volume+1), 20)", "(Less($open, $close)-$low)/$open"]
|
||||
- ["RESI5", "WVMA5", "RSQR5", "KLEN", "RSQR10", "CORR5", "CORD5", "CORR10", "RSQR20", "CORD60", "CORD10", "CORR20", "KLOW"]
|
||||
label:
|
||||
- ["Ref($close, -2)/Ref($close, -1) - 1"]
|
||||
- ["LABEL0"]
|
||||
freq: day
|
||||
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSZScoreNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: DataHandlerLP
|
||||
module_path: qlib.data.dataset.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of signal test for benchmark models. We will keep updating benchmark models in future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
|
||||
@@ -27,12 +27,11 @@ from qlib.tests.data import GetData
|
||||
from highfreq_ops import get_calendar_day, DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut
|
||||
|
||||
|
||||
class HighfreqWorkflow(object):
|
||||
class HighfreqWorkflow:
|
||||
|
||||
SPEC_CONF = {"custom_ops": [DayLast, FFillNan, BFillNan, Date, Select, IsNull, Cut], "expression_cache": None}
|
||||
|
||||
MARKET = "all"
|
||||
BENCHMARK = "SH000300"
|
||||
|
||||
start_time = "2020-09-15 00:00:00"
|
||||
end_time = "2021-01-18 16:00:00"
|
||||
@@ -146,35 +145,40 @@ class HighfreqWorkflow(object):
|
||||
|
||||
self._prepare_calender_cache()
|
||||
##=============reinit dataset=============
|
||||
dataset.init(
|
||||
dataset.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset.setup_data(
|
||||
handler_kwargs={
|
||||
"init_type": DataHandlerLP.IT_LS,
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.init(
|
||||
dataset_backtest.config(
|
||||
handler_kwargs={
|
||||
"start_time": "2021-01-19 00:00:00",
|
||||
"end_time": "2021-01-25 16:00:00",
|
||||
},
|
||||
segment_kwargs={
|
||||
segments={
|
||||
"test": (
|
||||
"2021-01-19 00:00:00",
|
||||
"2021-01-25 16:00:00",
|
||||
),
|
||||
},
|
||||
)
|
||||
dataset_backtest.setup_data(handler_kwargs={})
|
||||
|
||||
##=============get data=============
|
||||
xtest = dataset.prepare(["test"])
|
||||
backtest_test = dataset_backtest.prepare(["test"])
|
||||
xtest = dataset.prepare("test")
|
||||
backtest_test = dataset_backtest.prepare("test")
|
||||
|
||||
print(xtest, backtest_test)
|
||||
return
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
23
examples/hyperparameter/LightGBM/Readme.md
Normal file
23
examples/hyperparameter/LightGBM/Readme.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# LightGBM hyperparameter
|
||||
|
||||
## Alpha158
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_158 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_158.py
|
||||
```
|
||||
|
||||
## Alpha360
|
||||
First terminal
|
||||
```
|
||||
optuna create-study --study LGBM_360 --storage sqlite:///db.sqlite3
|
||||
optuna-dashboard --port 5000 --host 0.0.0.0 sqlite:///db.sqlite3
|
||||
```
|
||||
Second terminal
|
||||
```
|
||||
python hyperparameter_360.py
|
||||
```
|
||||
76
examples/hyperparameter/LightGBM/hyperparameter_158.py
Normal file
76
examples/hyperparameter/LightGBM/hyperparameter_158.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
|
||||
# Location of the qlib data bundle used by this search.
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
    print(f"Qlib data is not found in {provider_uri}")
    # Fetch the data with the downloader shipped inside qlib. The earlier
    # ``sys.path.append(str(scripts_dir))`` / ``from get_data import GetData``
    # path raised NameError: neither ``sys`` nor ``scripts_dir`` exists here.
    from qlib.tests.data import GetData

    GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

# Keyword arguments handed to the Alpha158 handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
}
# Dataset description; built once and shared by every trial of the study.
dataset_task = {
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    },
}
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
|
||||
|
||||
def objective(trial):
    """Fit one ``LGBModel`` with hyperparameters sampled from ``trial``.

    Args:
        trial: the optuna trial used to sample every tunable parameter.

    Returns:
        The minimum of ``evals_result["valid"]`` as filled in by ``model.fit``.
    """
    # NOTE(review): some keys look like LightGBM aliases of one another
    # (colsample_bytree/feature_fraction, subsample/bagging_fraction,
    # min_child_samples/min_data_in_leaf) -- confirm which ones should be tuned.
    task = {
        "model": {
            "class": "LGBModel",
            "module_path": "qlib.contrib.model.gbdt",
            "kwargs": {
                "loss": "mse",
                "colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
                "learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
                "subsample": trial.suggest_uniform("subsample", 0, 1),
                "lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
                "lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
                "max_depth": 10,
                # LightGBM requires num_leaves > 1, so the search must not start at 1.
                "num_leaves": trial.suggest_int("num_leaves", 2, 1024),
                "feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
                "bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
                "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
                "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
                "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
            },
        },
    }

    evals_result = dict()
    model = init_instance_by_config(task["model"])
    model.fit(dataset, evals_result=evals_result)
    return min(evals_result["valid"])
|
||||
|
||||
|
||||
# Attach to the study that was created beforehand with `optuna create-study`
# (see Readme.md). ``load_study`` is the public entry point for this; building
# ``optuna.Study`` directly is discouraged by optuna.
study = optuna.load_study(study_name="LGBM_158", storage="sqlite:///db.sqlite3")
# No ``n_trials`` is given, so the search keeps running until it is interrupted.
study.optimize(objective, n_jobs=6)
|
||||
76
examples/hyperparameter/LightGBM/hyperparameter_360.py
Normal file
76
examples/hyperparameter/LightGBM/hyperparameter_360.py
Normal file
@@ -0,0 +1,76 @@
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
import optuna
|
||||
|
||||
# Location of the qlib data bundle used by this search.
provider_uri = "~/.qlib/qlib_data/cn_data"
if not exists_qlib_data(provider_uri):
    print(f"Qlib data is not found in {provider_uri}")
    # Fetch the data with the downloader shipped inside qlib. The earlier
    # ``sys.path.append(str(scripts_dir))`` / ``from get_data import GetData``
    # path raised NameError: neither ``sys`` nor ``scripts_dir`` exists here.
    from qlib.tests.data import GetData

    GetData().qlib_data(target_dir=provider_uri, region="cn")
qlib.init(provider_uri=provider_uri, region="cn")

market = "csi300"
benchmark = "SH000300"

# Keyword arguments handed to the Alpha360 handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": market,
}
# Dataset description; built once and shared by every trial of the study.
dataset_task = {
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha360",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    },
}
dataset = init_instance_by_config(dataset_task["dataset"])
|
||||
|
||||
|
||||
def objective(trial):
    """Fit one ``LGBModel`` with hyperparameters sampled from ``trial``.

    Args:
        trial: the optuna trial used to sample every tunable parameter.

    Returns:
        The minimum of ``evals_result["valid"]`` as filled in by ``model.fit``.
    """
    # NOTE(review): some keys look like LightGBM aliases of one another
    # (colsample_bytree/feature_fraction, subsample/bagging_fraction,
    # min_child_samples/min_data_in_leaf) -- confirm which ones should be tuned.
    task = {
        "model": {
            "class": "LGBModel",
            "module_path": "qlib.contrib.model.gbdt",
            "kwargs": {
                "loss": "mse",
                "colsample_bytree": trial.suggest_uniform("colsample_bytree", 0.5, 1),
                "learning_rate": trial.suggest_uniform("learning_rate", 0, 1),
                "subsample": trial.suggest_uniform("subsample", 0, 1),
                "lambda_l1": trial.suggest_loguniform("lambda_l1", 1e-8, 1e4),
                "lambda_l2": trial.suggest_loguniform("lambda_l2", 1e-8, 1e4),
                "max_depth": 10,
                # LightGBM requires num_leaves > 1, so the search must not start at 1.
                "num_leaves": trial.suggest_int("num_leaves", 2, 1024),
                "feature_fraction": trial.suggest_uniform("feature_fraction", 0.4, 1.0),
                "bagging_fraction": trial.suggest_uniform("bagging_fraction", 0.4, 1.0),
                "bagging_freq": trial.suggest_int("bagging_freq", 1, 7),
                "min_data_in_leaf": trial.suggest_int("min_data_in_leaf", 1, 50),
                "min_child_samples": trial.suggest_int("min_child_samples", 5, 100),
            },
        },
    }

    evals_result = dict()
    model = init_instance_by_config(task["model"])
    model.fit(dataset, evals_result=evals_result)
    return min(evals_result["valid"])
|
||||
|
||||
|
||||
# Attach to the study that was created beforehand with `optuna create-study`
# (see Readme.md). ``load_study`` is the public entry point for this; building
# ``optuna.Study`` directly is discouraged by optuna.
study = optuna.load_study(study_name="LGBM_360", storage="sqlite:///db.sqlite3")
# No ``n_trials`` is given, so the search keeps running until it is interrupted.
study.optimize(objective, n_jobs=6)
|
||||
5
examples/hyperparameter/LightGBM/requirements.txt
Normal file
5
examples/hyperparameter/LightGBM/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
159
examples/model_rolling/task_manager_rolling.py
Normal file
159
examples/model_rolling/task_manager_rolling.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
|
||||
|
||||
# Keyword arguments for the Alpha158 data handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi100",
}

# Dataset template: handler plus the initial train/valid/test segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config,
        },
        "segments": {
            "train": ("2008-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2016-12-31"),
            "test": ("2017-01-01", "2020-08-01"),
        },
    },
}

# Records attached to every task.
record_config = [
    {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp"},
    {"class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp"},
]

# use lgb
task_lgb_config = {
    "model": {"class": "LGBModel", "module_path": "qlib.contrib.model.gbdt"},
    "dataset": dataset_config,
    "record": record_config,
}

# use xgboost
task_xgboost_config = {
    "model": {"class": "XGBModel", "module_path": "qlib.contrib.model.xgboost"},
    "dataset": dataset_config,
    "record": record_config,
}
|
||||
|
||||
|
||||
class RollingTaskExample:
    """Generate rolling tasks, train them with ``TrainerRM`` and collect the results.

    Each public method is one stage of the example; ``main`` runs them in order.
    """

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        experiment_name="rolling_exp",
        task_pool="rolling_task",
        task_config=None,
        rolling_step=550,
        rolling_type=RollingGen.ROLL_SD,
    ):
        """Initialise qlib and remember the settings of the example.

        Args:
            provider_uri: the qlib data location passed to ``qlib.init``.
            region: the region passed to ``qlib.init``.
            task_url: MongoDB url, stored under ``"task_url"`` in the mongo config.
            task_db_name: database name, stored under ``"task_db_name"``.
            experiment_name: experiment used for training, collecting and reset.
            task_pool: task pool name given to ``TaskManager`` and ``TrainerRM``.
            task_config: task template(s) to roll. ``None`` (the default) means
                ``[task_xgboost_config, task_lgb_config]``.
            rolling_step: ``step`` of the ``RollingGen``.
            rolling_type: ``rtype`` of the ``RollingGen``.
        """
        # Resolve the default here instead of in the signature, so that no
        # mutable list object is shared between instances.
        if task_config is None:
            task_config = [task_xgboost_config, task_lgb_config]
        # TaskManager config
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.experiment_name = experiment_name
        self.task_pool = task_pool
        self.task_config = task_config
        self.rolling_gen = RollingGen(step=rolling_step, rtype=rolling_type)

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        """Remove the task pool and delete every recorder of the experiment."""
        print("========== reset ==========")
        TaskManager(task_pool=self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.experiment_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    def task_generating(self):
        """Expand the task templates with the rolling generator and return them."""
        print("========== task_generating ==========")
        tasks = task_generator(
            tasks=self.task_config,
            generators=self.rolling_gen,  # generate different date segments
        )
        pprint(tasks)
        return tasks

    def task_training(self, tasks):
        """Train ``tasks`` with a ``TrainerRM`` bound to the experiment and task pool."""
        print("========== task_training ==========")
        trainer = TrainerRM(self.experiment_name, self.task_pool)
        trainer.train(tasks)

    def task_collecting(self):
        """Collect and print the results of the ``LGBModel`` recorders."""
        print("========== task_collecting ==========")

        def rec_key(recorder):
            # Key a recorder by (model class, test segment) of its stored task.
            task_config = recorder.load_object("task")
            model_key = task_config["model"]["class"]
            rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"]
            return model_key, rolling_key

        def my_filter(recorder):
            # only choose the results of "LGBModel"
            model_key, _ = rec_key(recorder)
            return model_key == "LGBModel"

        collector = RecorderCollector(
            experiment=self.experiment_name,
            process_list=RollingGroup(),
            rec_key_func=rec_key,
            rec_filter_func=my_filter,
        )
        print(collector())

    def main(self):
        """Run reset, task generation, training and collecting in sequence."""
        self.reset()
        tasks = self.task_generating()
        self.task_training(tasks)
        self.task_collecting()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the whole process with your own parameters, for example:
    #   python task_manager_rolling.py main --experiment_name="your_exp_name"
    fire.Fire(RollingTaskExample)
|
||||
146
examples/online_srv/online_management_simulate.py
Normal file
146
examples/online_srv/online_management_simulate.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how to simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
# Keyword arguments for the Alpha158 data handler.
data_handler_config = {
    "start_time": "2018-01-01",
    "end_time": "2018-10-31",
    "fit_start_time": "2018-01-01",
    "fit_end_time": "2018-03-31",
    "instruments": "csi100",
}

# Dataset template: handler plus the initial train/valid/test segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config,
        },
        "segments": {
            "train": ("2018-01-01", "2018-03-31"),
            "valid": ("2018-04-01", "2018-05-31"),
            "test": ("2018-06-01", "2018-09-10"),
        },
    },
}

# Records attached to every task.
record_config = [
    {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp"},
    {"class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp"},
]

# use lgb model
task_lgb_config = {
    "model": {"class": "LGBModel", "module_path": "qlib.contrib.model.gbdt"},
    "dataset": dataset_config,
    "record": record_config,
}

# use xgboost model
task_xgboost_config = {
    "model": {"class": "XGBModel", "module_path": "qlib.contrib.model.xgboost"},
    "dataset": dataset_config,
    "record": record_config,
}
|
||||
|
||||
|
||||
class OnlineSimulationExample:
    """Simulate an ``OnlineManager`` that rolls and trains the given tasks."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        exp_name="rolling_exp",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        task_pool="rolling_task",
        rolling_step=80,
        start_time="2018-09-10",
        end_time="2018-10-31",
        tasks=None,
    ):
        """
        Init OnlineSimulationExample.

        Args:
            provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data".
            region (str, optional): the stock region. Defaults to "cn".
            exp_name (str, optional): the experiment name. Defaults to "rolling_exp".
            task_url (str, optional): your MongoDB url. Defaults to "mongodb://10.0.0.4:27017/".
            task_db_name (str, optional): database name. Defaults to "rolling_db".
            task_pool (str, optional): the task pool name (a task pool is a collection in MongoDB). Defaults to "rolling_task".
            rolling_step (int, optional): the step for rolling. Defaults to 80.
            start_time (str, optional): the start time of simulating. Defaults to "2018-09-10".
            end_time (str, optional): the end time of simulating. Defaults to "2018-10-31".
            tasks (dict or list[dict], optional): a set of the task config waiting for rolling and training.
                Defaults to None, which means [task_xgboost_config, task_lgb_config].
        """
        # Resolve the default here instead of in the signature, so that no
        # mutable list object is shared between instances.
        if tasks is None:
            tasks = [task_xgboost_config, task_lgb_config]
        self.exp_name = exp_name
        self.task_pool = task_pool
        self.start_time = start_time
        self.end_time = end_time
        mongo_conf = {
            "task_url": task_url,
            "task_db_name": task_db_name,
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        # The rolling tasks generator, ds_extra_mod_func is None because we just need to
        # simulate to 2018-10-31 and needn't change the handler end time.
        self.rolling_gen = RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None)
        self.trainer = DelayTrainerRM(self.exp_name, self.task_pool)  # Also can be TrainerR, TrainerRM, DelayTrainerR
        self.rolling_online_manager = OnlineManager(
            RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
            trainer=self.trainer,
            begin_time=self.start_time,
        )
        self.tasks = tasks

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        """Remove the task pool and delete every recorder of the experiment."""
        TaskManager(self.task_pool).remove()
        exp = R.get_exp(experiment_name=self.exp_name)
        for rid in exp.list_recorders():
            exp.delete_recorder(rid)

    # Run this to run all workflow automatically
    def main(self):
        """Reset, simulate until ``end_time``, then print the collected results and signals."""
        print("========== reset ==========")
        self.reset()
        print("========== simulate ==========")
        self.rolling_online_manager.simulate(end_time=self.end_time)
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the whole workflow automatically with your own parameters, for example:
    #   python online_management_simulate.py main --exp_name="your_exp_name" --rolling_step=60
    fire.Fire(OnlineSimulationExample)
|
||||
181
examples/online_srv/rolling_online_management.py
Normal file
181
examples/online_srv/rolling_online_management.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineManager works with rolling tasks.
|
||||
There are four parts including first train, routine 1, add strategy and routine 2.
|
||||
Firstly, the OnlineManager will finish the first training and set trained models to `online` models.
|
||||
Next, the OnlineManager will finish a routine process, including update online prediction -> prepare tasks -> prepare new models -> prepare signals
|
||||
Then, we will add some new strategies to the OnlineManager. This will finish first training of new strategies.
|
||||
Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
"""
|
||||
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
|
||||
# Keyword arguments for the Alpha158 data handler.
data_handler_config = {
    "start_time": "2013-01-01",
    "end_time": "2020-09-25",
    "fit_start_time": "2013-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi100",
}

# Dataset template: handler plus the initial train/valid/test segments.
dataset_config = {
    "class": "DatasetH",
    "module_path": "qlib.data.dataset",
    "kwargs": {
        "handler": {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": data_handler_config,
        },
        "segments": {
            "train": ("2013-01-01", "2014-12-31"),
            "valid": ("2015-01-01", "2015-12-31"),
            "test": ("2016-01-01", "2020-07-10"),
        },
    },
}

# Records attached to every task.
record_config = [
    {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp"},
    {"class": "SigAnaRecord", "module_path": "qlib.workflow.record_temp"},
]

# use lgb model
task_lgb_config = {
    "model": {"class": "LGBModel", "module_path": "qlib.contrib.model.gbdt"},
    "dataset": dataset_config,
    "record": record_config,
}

# use xgboost model
task_xgboost_config = {
    "model": {"class": "XGBModel", "module_path": "qlib.contrib.model.xgboost"},
    "dataset": dataset_config,
    "record": record_config,
}
|
||||
|
||||
|
||||
class RollingOnlineExample:
    """Drive an ``OnlineManager`` through first training, routines and adding strategies.

    The manager is pickled to ``_ROLLING_MANAGER_PATH`` after each stage and
    loaded again at the start of ``routine`` and ``add_strategy``.
    """

    # the OnlineManager will dump to this file, for it can be loaded when calling routine.
    _ROLLING_MANAGER_PATH = ".RollingOnlineExample"

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region="cn",
        task_url="mongodb://10.0.0.4:27017/",
        task_db_name="rolling_db",
        rolling_step=550,
        tasks=None,
        add_tasks=None,
    ):
        """Initialise qlib and build an ``OnlineManager`` with one strategy per task.

        Args:
            provider_uri: the qlib data location passed to ``qlib.init``.
            region: the region passed to ``qlib.init``.
            task_url: your MongoDB url.
            task_db_name: database name.
            rolling_step: ``step`` of every ``RollingGen`` created by this example.
            tasks: tasks whose strategies are created immediately. ``None``
                (the default) means ``[task_xgboost_config]``.
            add_tasks: tasks whose strategies are added by ``add_strategy``.
                ``None`` (the default) means ``[task_lgb_config]``.
        """
        # Resolve the defaults here instead of in the signature, so that no
        # mutable list object is shared between instances.
        if tasks is None:
            tasks = [task_xgboost_config]
        if add_tasks is None:
            add_tasks = [task_lgb_config]
        mongo_conf = {
            "task_url": task_url,  # your MongoDB url
            "task_db_name": task_db_name,  # database name
        }
        qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf)
        self.tasks = tasks
        self.add_tasks = add_tasks
        self.rolling_step = rolling_step
        self.rolling_online_manager = OnlineManager(self._build_strategies(tasks))

    def _build_strategies(self, tasks):
        """Return one ``RollingStrategy`` per task, named after the task's model class."""
        strategies = []
        for task in tasks:
            name_id = task["model"]["class"]  # NOTE: Assumption: The model class can specify only one strategy
            strategies.append(
                RollingStrategy(
                    name_id,
                    task,
                    RollingGen(step=self.rolling_step, rtype=RollingGen.ROLL_SD),
                )
            )
        return strategies

    # Reset all things to the first status, be careful to save important data
    def reset(self):
        """Delete the recorders of every task's experiment and the dumped manager file."""
        for task in self.tasks + self.add_tasks:
            name_id = task["model"]["class"]
            exp = R.get_exp(experiment_name=name_id)
            for rid in exp.list_recorders():
                exp.delete_recorder(rid)

        if os.path.exists(self._ROLLING_MANAGER_PATH):
            os.remove(self._ROLLING_MANAGER_PATH)

    def first_run(self):
        """Reset, run the first training, print the collected results and dump the manager."""
        print("========== reset ==========")
        self.reset()
        print("========== first_run ==========")
        self.rolling_online_manager.first_train()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def routine(self):
        """Load the manager, run one routine, print results and signals, then dump it again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== routine ==========")
        self.rolling_online_manager.routine()
        print("========== collect results ==========")
        print(self.rolling_online_manager.get_collector()())
        print("========== signals ==========")
        print(self.rolling_online_manager.get_signals())
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def add_strategy(self):
        """Load the manager, add the strategies of ``add_tasks`` and dump it again."""
        print("========== load ==========")
        self.rolling_online_manager = OnlineManager.load(self._ROLLING_MANAGER_PATH)
        print("========== add strategy ==========")
        strategies = self._build_strategies(self.add_tasks)
        self.rolling_online_manager.add_strategy(strategies=strategies)
        print("========== dump ==========")
        self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH)

    def main(self):
        """Run first_run, routine, add_strategy and routine again, in that order."""
        self.first_run()
        self.routine()
        self.add_strategy()
        self.routine()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Train the first version's models:
    #   python rolling_online_management.py first_run
    #
    # Update the models and predictions after the trading time:
    #   python rolling_online_management.py routine
    #
    # Define your own parameters with `--` (they must be constructor arguments), e.g.:
    #   python rolling_online_management.py first_run --task_db_name='your_db_name' --rolling_step=40
    fire.Fire(RollingOnlineExample)
|
||||
91
examples/online_srv/update_online_pred.py
Normal file
91
examples/online_srv/update_online_pred.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This example shows how OnlineTool works when we need to update predictions.
|
||||
There are two parts including first_train and update_online_pred.
|
||||
Firstly, we will finish the training and set the trained models to the `online` models.
|
||||
Next, we will finish updating online predictions.
|
||||
"""
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.model.trainer import task_train
|
||||
from qlib.workflow.online.utils import OnlineToolR
|
||||
|
||||
# Keyword arguments for the Alpha158 data handler.
data_handler_config = {
    "start_time": "2008-01-01",
    "end_time": "2020-08-01",
    "fit_start_time": "2008-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi100",
}

# The single task trained by this example: model, dataset and record.
task = {
    "model": {
        "class": "LGBModel",
        "module_path": "qlib.contrib.model.gbdt",
        "kwargs": {
            "loss": "mse",
            "colsample_bytree": 0.8879,
            "learning_rate": 0.0421,
            "subsample": 0.8789,
            "lambda_l1": 205.6999,
            "lambda_l2": 580.9768,
            "max_depth": 8,
            "num_leaves": 210,
            "num_threads": 20,
        },
    },
    "dataset": {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {
            "handler": {
                "class": "Alpha158",
                "module_path": "qlib.contrib.data.handler",
                "kwargs": data_handler_config,
            },
            "segments": {
                "train": ("2008-01-01", "2014-12-31"),
                "valid": ("2015-01-01", "2016-12-31"),
                "test": ("2017-01-01", "2020-08-01"),
            },
        },
    },
    "record": {"class": "SignalRecord", "module_path": "qlib.workflow.record_temp"},
}
|
||||
|
||||
|
||||
class UpdatePredExample:
    """Train one task, tag it as the online model, then update online predictions."""

    def __init__(
        self,
        provider_uri="~/.qlib/qlib_data/cn_data",
        region=REG_CN,
        experiment_name="online_srv",
        task_config=task,
    ):
        """Initialise qlib and create the ``OnlineToolR`` for ``experiment_name``."""
        qlib.init(provider_uri=provider_uri, region=region)
        self.experiment_name = experiment_name
        self.online_tool = OnlineToolR(self.experiment_name)
        self.task_config = task_config

    def first_train(self):
        """Train ``task_config`` and mark the resulting recorder as online."""
        recorder = task_train(self.task_config, experiment_name=self.experiment_name)
        # set to online model
        self.online_tool.reset_online_tag(recorder)

    def update_online_pred(self):
        """Delegate to ``OnlineToolR.update_online_pred``."""
        self.online_tool.update_online_pred()

    def main(self):
        """Run ``first_train`` followed by ``update_online_pred``."""
        self.first_train()
        self.update_online_pred()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Train a model and set it as the online model:
    #   python update_online_pred.py first_train
    #
    # Update online predictions once a day:
    #   python update_online_pred.py update_online_pred
    #
    # Run the whole process with your own parameters:
    #   python update_online_pred.py main --experiment_name="your_exp_name"
    fire.Fire(UpdatePredExample)
|
||||
17
examples/rolling_process_data/README.md
Normal file
17
examples/rolling_process_data/README.md
Normal file
@@ -0,0 +1,17 @@
|
||||
# Rolling Process Data
|
||||
|
||||
This workflow is an example for `Rolling Process Data`.
|
||||
|
||||
## Background
|
||||
|
||||
When models are trained in a rolling manner, the data also needs to be generated for each rolling window. When the rolling window moves, the training data changes, and the processor's learnable state (such as standard deviation, mean, etc.) changes as well.
|
||||
|
||||
In order to avoid regenerating data, this example uses the `DataHandler-based DataLoader` to load the raw features that are not related to the rolling window, and then uses Processors to generate the processed features that are related to the rolling window.
|
||||
|
||||
|
||||
## Run the Code
|
||||
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py rolling_process
|
||||
```
|
||||
32
examples/rolling_process_data/rolling_handler.py
Normal file
32
examples/rolling_process_data/rolling_handler.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.loader import DataLoaderDH
|
||||
from qlib.contrib.data.handler import check_transform_proc
|
||||
|
||||
|
||||
class RollingDataHandler(DataHandlerLP):
    """``DataHandlerLP`` whose data comes from a ``DataLoaderDH`` data loader."""

    def __init__(
        self,
        start_time=None,
        end_time=None,
        infer_processors=None,
        learn_processors=None,
        fit_start_time=None,
        fit_end_time=None,
        data_loader_kwargs=None,
    ):
        """Build the handler.

        Args:
            start_time: forwarded to ``DataHandlerLP`` as ``start_time``.
            end_time: forwarded to ``DataHandlerLP`` as ``end_time``.
            infer_processors: processor configs, passed through
                ``check_transform_proc``. ``None`` (the default) means ``[]``.
            learn_processors: processor configs, passed through
                ``check_transform_proc``. ``None`` (the default) means ``[]``.
            fit_start_time: given to ``check_transform_proc`` for both processor lists.
            fit_end_time: given to ``check_transform_proc`` for both processor lists.
            data_loader_kwargs: keyword arguments of the ``DataLoaderDH``.
                ``None`` (the default) means ``{}``.
        """
        # Resolve the defaults here instead of in the signature, so that no
        # mutable list/dict object is shared between calls.
        if infer_processors is None:
            infer_processors = []
        if learn_processors is None:
            learn_processors = []
        if data_loader_kwargs is None:
            data_loader_kwargs = {}

        infer_processors = check_transform_proc(infer_processors, fit_start_time, fit_end_time)
        learn_processors = check_transform_proc(learn_processors, fit_start_time, fit_end_time)

        data_loader = {
            "class": "DataLoaderDH",
            # shallow copy, so the caller's dict is not the object stored in the config
            "kwargs": {**data_loader_kwargs},
        }

        super().__init__(
            instruments=None,
            start_time=start_time,
            end_time=end_time,
            data_loader=data_loader,
            infer_processors=infer_processors,
            learn_processors=learn_processors,
        )
|
||||
141
examples/rolling_process_data/workflow.py
Normal file
141
examples/rolling_process_data/workflow.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import qlib
|
||||
import fire
|
||||
import pickle
|
||||
import pandas as pd
|
||||
|
||||
from datetime import datetime
|
||||
from qlib.config import REG_CN
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
|
||||
class RollingDataWorkflow:
    """Re-process pre-dumped Alpha158 data for a sequence of rolling windows.

    ``rolling_process`` dumps an ``Alpha158`` handler to a pickle file, loads it
    back, and feeds it to a ``RollingDataHandler``-based dataset whose dates are
    moved forward by one year for each of ``rolling_cnt`` iterations.
    """

    MARKET = "csi300"
    start_time = "2010-01-01"
    end_time = "2019-12-31"
    rolling_cnt = 5

    @staticmethod
    def _shifted(date_parts, years=0):
        """Return the ``datetime`` for ``(year, month, day)`` moved forward by ``years`` years."""
        year, *month_day = date_parts
        return datetime(year + years, *month_day)

    def _init_qlib(self):
        """initialize qlib"""
        # use the cn_data bundle; download it first when it is not present
        provider_uri = "~/.qlib/qlib_data/cn_data"  # target_dir
        if not exists_qlib_data(provider_uri):
            print(f"Qlib data is not found in {provider_uri}")
            GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
        qlib.init(provider_uri=provider_uri, region=REG_CN)

    def _dump_pre_handler(self, path):
        """Create an ``Alpha158`` handler without processors and pickle it to ``path``."""
        handler_config = {
            "class": "Alpha158",
            "module_path": "qlib.contrib.data.handler",
            "kwargs": {
                "start_time": self.start_time,
                "end_time": self.end_time,
                "instruments": self.MARKET,
                "infer_processors": [],
                "learn_processors": [],
            },
        }
        pre_handler = init_instance_by_config(handler_config)
        pre_handler.config(dump_all=True)
        pre_handler.to_pickle(path)

    def _load_pre_handler(self, path):
        """Unpickle and return the object stored at ``path``."""
        # NOTE: pickle.load executes arbitrary code from the file; only use it on
        # files this workflow wrote itself (see _dump_pre_handler).
        with open(path, "rb") as file_dataset:
            pre_handler = pickle.load(file_dataset)
        return pre_handler

    def rolling_process(self):
        """Run the rolling example and print the train/valid/test data of every window."""
        self._init_qlib()
        self._dump_pre_handler("pre_handler.pkl")
        pre_handler = self._load_pre_handler("pre_handler.pkl")

        # (year, month, day) of the first window; later windows add whole years.
        train_start_time = (2010, 1, 1)
        train_end_time = (2012, 12, 31)
        valid_start_time = (2013, 1, 1)
        valid_end_time = (2013, 12, 31)
        test_start_time = (2014, 1, 1)
        test_end_time = (2014, 12, 31)

        def segments_at(offset):
            # train/valid/test ranges of the window that is ``offset`` years after the first
            return {
                "train": (self._shifted(train_start_time, offset), self._shifted(train_end_time, offset)),
                "valid": (self._shifted(valid_start_time, offset), self._shifted(valid_end_time, offset)),
                "test": (self._shifted(test_start_time, offset), self._shifted(test_end_time, offset)),
            }

        dataset_config = {
            "class": "DatasetH",
            "module_path": "qlib.data.dataset",
            "kwargs": {
                "handler": {
                    "class": "RollingDataHandler",
                    "module_path": "rolling_handler",
                    "kwargs": {
                        "start_time": self._shifted(train_start_time),
                        "end_time": self._shifted(test_end_time),
                        "fit_start_time": self._shifted(train_start_time),
                        "fit_end_time": self._shifted(train_end_time),
                        "infer_processors": [
                            {"class": "RobustZScoreNorm", "kwargs": {"fields_group": "feature"}},
                        ],
                        "learn_processors": [
                            {"class": "DropnaLabel"},
                            {"class": "CSZScoreNorm", "kwargs": {"fields_group": "label"}},
                        ],
                        "data_loader_kwargs": {
                            "handler_config": pre_handler,
                        },
                    },
                },
                "segments": segments_at(0),
            },
        }

        dataset = init_instance_by_config(dataset_config)

        for rolling_offset in range(self.rolling_cnt):

            print(f"===========rolling{rolling_offset} start===========")
            # The first window is already configured by dataset_config above.
            if rolling_offset:
                dataset.config(
                    handler_kwargs={
                        "start_time": self._shifted(train_start_time, rolling_offset),
                        "end_time": self._shifted(test_end_time, rolling_offset),
                        "processor_kwargs": {
                            "fit_start_time": self._shifted(train_start_time, rolling_offset),
                            "fit_end_time": self._shifted(train_end_time, rolling_offset),
                        },
                    },
                    segments=segments_at(rolling_offset),
                )
                dataset.setup_data(
                    handler_kwargs={
                        "init_type": DataHandlerLP.IT_FIT_SEQ,
                    }
                )

            dtrain, dvalid, dtest = dataset.prepare(["train", "valid", "test"])
            print(dtrain, dvalid, dtest)
            ## print or dump data
            print(f"===========rolling{rolling_offset} end===========")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose the workflow on the command line, e.g. `python workflow.py rolling_process`.
    fire.Fire(RollingDataWorkflow)
|
||||
@@ -28,11 +28,17 @@
|
||||
"import sys, site\n",
|
||||
"from pathlib import Path\n",
|
||||
"\n",
|
||||
"################################# NOTE #################################\n",
|
||||
"# Please be aware that if colab installs the latest numpy and pyqlib #\n",
|
||||
"# in this cell, users should RESTART the runtime in order to run the #\n",
|
||||
"# following cells successfully. #\n",
|
||||
"########################################################################\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" import qlib\n",
|
||||
"except ImportError:\n",
|
||||
" # install qlib\n",
|
||||
" ! pip install --upgrade numpy\n",
|
||||
" ! pip install pyqlib\n",
|
||||
" # reload\n",
|
||||
" site.main()\n",
|
||||
@@ -238,9 +244,7 @@
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"scrolled": false
|
||||
},
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
@@ -359,7 +363,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.9"
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -377,4 +381,4 @@
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
|
||||
__version__ = "0.6.3.99"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
|
||||
|
||||
import os
|
||||
@@ -10,12 +11,13 @@ import yaml
|
||||
import logging
|
||||
import platform
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from .log import get_module_logger
|
||||
|
||||
|
||||
# init qlib
|
||||
def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .log import get_module_logger
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
@@ -48,7 +50,6 @@ def init(default_conf="client", **kwargs):
|
||||
|
||||
|
||||
def _mount_nfs_uri(C):
|
||||
from .log import get_module_logger
|
||||
|
||||
LOG = get_module_logger("mount nfs", level=logging.INFO)
|
||||
|
||||
@@ -151,3 +152,74 @@ def init_from_yaml_conf(conf_path, **kwargs):
|
||||
config.update(kwargs)
|
||||
default_conf = config.pop("default_conf", "client")
|
||||
init(default_conf, **config)
|
||||
|
||||
|
||||
def get_project_path(config_name="config.yaml", cur_path=None) -> Path:
    """
    Locate the project directory by searching upwards for a config file.

    Starting at `cur_path`, every ancestor is checked in turn and the first
    path that directly contains a file named `config_name` is returned.

    For example:
        If your project file system structure follows such a pattern

            <project_path>/
              - config.yaml
              - ...some folders...
                - qlib/

        This function will return <project_path>

    NOTE: link is not supported here.

    This method is often used when
    - user want to use a relative config path instead of hard-coding qlib config path in code

    Parameters
    ----------
    config_name : str
        name of the file that marks the project directory
    cur_path : str or Path, optional
        where the upward search starts. When it is None, the search starts
        from the resolved absolute location of this source file.

    Returns
    -------
    Path
        the first path (starting at `cur_path`, going upwards) that contains `config_name`

    Raises
    ------
    FileNotFoundError:
        If project path is not found
    """
    if cur_path is None:
        cur_path = Path(__file__).absolute().resolve()
    else:
        # Accept plain strings as well: `str / str` would raise a TypeError below.
        cur_path = Path(cur_path)
    while True:
        if (cur_path / config_name).exists():
            return cur_path
        # A path that is its own parent is the filesystem root: nothing left to search.
        if cur_path == cur_path.parent:
            raise FileNotFoundError("We can't find the project path")
        cur_path = cur_path.parent
|
||||
|
||||
|
||||
def auto_init(**kwargs):
    """
    This function will init qlib automatically with following priority
    - Find the project configuration and init qlib
        - The parsing process will be affected by the `conf_type` of the configuration file
    - Init qlib with default config

    Parameters
    ----------
    **kwargs :
        `cur_path` (optional) is popped and forwarded to `get_project_path` as the
        start of the project search; all remaining items are forwarded to
        `init` / `init_from_yaml_conf`.
    """
    try:
        pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
    except FileNotFoundError:
        # No project configuration found: fall back to the default initialization.
        init(**kwargs)
        return

    conf_pp = pp / "config.yaml"
    with conf_pp.open() as f:
        # `yaml.safe_load` returns None for an empty document; treat it as an empty config.
        conf = yaml.safe_load(f) or {}

    conf_type = conf.get("conf_type", "origin")
    if conf_type == "origin":
        # The type of config is just like original qlib config
        init_from_yaml_conf(conf_pp, **kwargs)
    elif conf_type == "ref":
        # This config type will be more convenient in following scenario
        # - There is a shared configure file and you don't want to edit it inplace.
        # - The shared configure may be updated later and you don't want to copy it.
        # - You have some customized config.
        qlib_conf_path = conf["qlib_cfg"]
        # `qlib_cfg_update` is optional; a missing (or null) entry means "no overrides".
        qlib_conf_update = conf.get("qlib_cfg_update") or {}
        # Merge before the call so that a key given in both places does not raise
        # "got multiple values for keyword argument"; explicit kwargs win.
        init_from_yaml_conf(qlib_conf_path, **{**qlib_conf_update, **kwargs})
    # NOTE(review): any other `conf_type` falls through without initializing qlib
    # but is still logged below — confirm whether that should be an error.
    logger = get_module_logger("Initialization")
    logger.info(f"Auto load project config: {conf_pp}")
|
||||
|
||||
@@ -33,6 +33,9 @@ class Config:
|
||||
|
||||
raise AttributeError(f"No such {attr} in self._config")
|
||||
|
||||
def get(self, key, default=None):
    """Return the config value stored under `key`, or `default` when the key is absent."""
    # Read the backing dict through the instance `__dict__`, exactly as the
    # neighbouring `__setitem__` does.
    config = self.__dict__["_config"]
    return config.get(key, default)
|
||||
|
||||
def __setitem__(self, key, value):
|
||||
self.__dict__["_config"][key] = value
|
||||
|
||||
@@ -131,7 +134,7 @@ _default_config = {
|
||||
},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
# Default config for experiment manager
|
||||
"exp_manager": {
|
||||
"class": "MLflowExpManager",
|
||||
"module_path": "qlib.workflow.expm",
|
||||
@@ -140,6 +143,11 @@ _default_config = {
|
||||
"default_exp_name": "Experiment",
|
||||
},
|
||||
},
|
||||
# Default config for MongoDB
|
||||
"mongo": {
|
||||
"task_url": "mongodb://localhost:27017/",
|
||||
"task_db_name": "default_task_db",
|
||||
},
|
||||
}
|
||||
|
||||
MODE_CONF = {
|
||||
@@ -310,8 +318,22 @@ class QlibConfig(Config):
|
||||
# clean up experiment when python program ends
|
||||
experiment_exit_handler()
|
||||
|
||||
# Supporting user reset qlib version (useful when user want to connect to qlib server with old version)
|
||||
self.reset_qlib_version()
|
||||
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
    """
    Set `qlib.__version__` from the `qlib_reset_version` config entry.

    When the entry is missing (or None), `qlib.__version__` is restored from the
    backup attribute `qlib.__version__bak`.
    """
    import qlib

    override = self.get("qlib_reset_version", None)
    # The backup attribute is read with `getattr` and a string name on purpose:
    # inside a class body an identifier such as `__version__bak` (two leading
    # underscores, not ending with two underscores) is name-mangled to
    # `_QlibConfig__version__bak`, so `qlib.__version__bak` would not resolve.
    # `__version__` itself ends with two underscores and is not mangled.
    qlib.__version__ = getattr(qlib, "__version__bak") if override is None else override
|
||||
|
||||
@property
|
||||
def registered(self):
|
||||
return self._registered
|
||||
|
||||
@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
"""
|
||||
Update the account and strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .order import Order
|
||||
|
||||
"""
|
||||
@@ -128,7 +128,7 @@ class Position:
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
|
||||
@@ -26,6 +26,7 @@ def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
"fit_end_time": fit_end_time,
|
||||
}
|
||||
)
|
||||
# FIXME: the `module_path` parameter is missed.
|
||||
new_l.append({"class": klass.__name__, "kwargs": pkwargs})
|
||||
else:
|
||||
new_l.append(p)
|
||||
|
||||
@@ -8,6 +8,59 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_long_short_prec(
    pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
    """
    calculate the precision for long and short operation

    For every date, the `quantile` fraction of rows with the highest prediction is
    treated as the long target and the fraction with the lowest prediction as the
    short target. Long precision is the share of long targets whose label is > 0;
    short precision is the share of short targets whose label is < 0.

    Parameters
    ----------
    pred : pd.Series
        prediction; index is **pd.MultiIndex**, index name is **[datetime, instruments]**.

        .. code-block:: python

                                              score
            datetime            instrument
            2020-12-01 09:30:00 SH600068    0.553634
                                SH600195    0.550017
                                SH600276    0.540321
                                SH600584    0.517297
                                SH600715    0.544674

    label : pd.Series
        label, indexed like `pred`
    date_col : str
        name of the index level that holds the date
    quantile : float
        fraction of each date's rows selected on the long side and on the short side
    dropna : bool
        drop rows where prediction or label is NaN before selecting
    is_alpha : bool
        subtract the per-date mean from the label before evaluating

    Returns
    -------
    (pd.Series, pd.Series)
        long precision and short precision in time level

    Raises
    ------
    ValueError
        if `int(1 / quantile)` is not smaller than the number of distinct values in
        the second index level of `label`
    """
    if is_alpha:
        # `Series.mean(level=...)` no longer exists in pandas 2.x; a grouped
        # transform yields the per-date mean already aligned with `label`.
        label = label - label.groupby(level=date_col).transform("mean")
    if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
        raise ValueError("Need more instruments to calculate precision")

    df = pd.DataFrame({"pred": pred, "label": label})
    if dropna:
        df.dropna(inplace=True)

    # find the top/low quantile of prediction and treat them as long and short target
    # An explicit loop keeps the (date, instrument) index unchanged on every pandas
    # version, unlike `groupby.apply`, whose handling of group keys has changed.
    long_parts, short_parts = [], []
    for _, day_df in df.groupby(level=date_col):
        n_pick = int(len(day_df) * quantile)
        long_parts.append(day_df.nlargest(n_pick, columns="pred")["label"])
        short_parts.append(day_df.nsmallest(n_pick, columns="pred")["label"])

    def _precision(parts, is_hit):
        """Per-date ratio of hits to non-NaN labels among the selected rows."""
        picked = pd.concat(parts) if parts else df["label"].iloc[:0]
        hits = is_hit(picked).groupby(level=date_col).sum()
        return hits / picked.groupby(level=date_col).count()

    return _precision(long_parts, lambda s: s > 0), _precision(short_parts, lambda s: s < 0)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
try:
|
||||
from .catboost_model import CatBoostModel
|
||||
except ModuleNotFoundError:
|
||||
CatBoostModel = None
|
||||
print("Please install necessary libs for CatBoostModel.")
|
||||
try:
|
||||
from .double_ensemble import DEnsembleModel
|
||||
from .gbdt import LGBModel
|
||||
except ModuleNotFoundError:
|
||||
DEnsembleModel, LGBModel = None, None
|
||||
print("Please install necessary libs for DEnsembleModel and LGBModel, such as lightgbm.")
|
||||
try:
|
||||
from .xgboost import XGBModel
|
||||
except ModuleNotFoundError:
|
||||
XGBModel = None
|
||||
print("Please install necessary libs for XGBModel, such as xgboost.")
|
||||
try:
|
||||
from .linear import LinearModel
|
||||
except ModuleNotFoundError:
|
||||
LinearModel = None
|
||||
print("Please install necessary libs for LinearModel, such as scipy and sklearn.")
|
||||
# import pytorch models
|
||||
try:
|
||||
from .pytorch_alstm import ALSTM
|
||||
from .pytorch_gats import GATs
|
||||
from .pytorch_gru import GRU
|
||||
from .pytorch_lstm import LSTM
|
||||
from .pytorch_nn import DNNModelPytorch
|
||||
from .pytorch_tabnet import TabnetModel
|
||||
from .pytorch_sfm import SFM_Model
|
||||
|
||||
pytorch_classes = (ALSTM, GATs, GRU, LSTM, DNNModelPytorch, TabnetModel, SFM_Model)
|
||||
except ModuleNotFoundError:
|
||||
pytorch_classes = ()
|
||||
print("Please install necessary libs for PyTorch models.")
|
||||
|
||||
all_model_classes = (CatBoostModel, DEnsembleModel, LGBModel, XGBModel, LinearModel) + pytorch_classes
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from catboost import Pool, CatBoost
|
||||
from catboost.utils import get_gpu_device_count
|
||||
|
||||
@@ -62,10 +63,10 @@ class CatBoostModel(Model):
|
||||
evals_result["train"] = list(evals_result["learn"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["validation"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,6 +40,10 @@ class DEnsembleModel(Model):
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if sample_ratios is None: # the default values for sample_ratios
|
||||
sample_ratios = [0.8, 0.7, 0.6, 0.5, 0.4]
|
||||
if sub_weights is None: # the default values for sub_weights
|
||||
sub_weights = [1.0, 0.2, 0.2, 0.2, 0.2, 0.2]
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
@@ -228,10 +232,10 @@ class DEnsembleModel(Model):
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import ModelFT
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -61,10 +61,10 @@ class LGBModel(ModelFT):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
|
||||
|
||||
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
import warnings
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT):
    """LightGBM Model for high frequency prediction"""

    def __init__(self, loss="mse", **kwargs):
        """
        Parameters
        ----------
        loss : str
            LightGBM objective; only "mse" and "binary" are accepted
        **kwargs :
            extra LightGBM parameters, merged into `self.params`
        """
        if loss not in {"mse", "binary"}:
            raise NotImplementedError
        self.params = {"objective": loss, "verbosity": -1}
        self.params.update(kwargs)
        self.model = None

    def _cal_signal_metrics(self, y_test, l_cut, r_cut):
        """
        Calculate the signal metrics by daily level

        Parameters
        ----------
        y_test : pd.DataFrame
            first column is the label (alpha), plus a `pred` column; the first
            index level is used as the date
        l_cut : float
            fraction of each day's rows (lowest predictions) used as the down side
        r_cut : float
            rows after this fraction (highest predictions) are used as the up side

        Returns
        -------
        tuple
            means over the evaluated days of
            (up precision, down precision, up alpha, down alpha)
        """
        up_pre, down_pre = [], []
        up_alpha_ll, down_alpha_ll = [], []
        for date in y_test.index.get_level_values(0).unique():
            df_res = y_test.loc[date].sort_values("pred")
            n_low = int(l_cut * len(df_res))
            if n_low < 10:
                # Too few rows on this day for a meaningful estimate: skip it.
                warnings.warn("Warning: threhold is too low or instruments number is not enough")
                continue
            label_col = df_res.columns[0]
            # Rows are sorted ascending by prediction, so `top` holds the LOWEST
            # predictions (down signals) and `bottom` the HIGHEST (up signals).
            top = df_res.iloc[:n_low]
            bottom = df_res.iloc[int(r_cut * len(df_res)) :]

            down_precision = len(top[top[label_col] < 0]) / (len(top))
            up_precision = len(bottom[bottom[label_col] > 0]) / (len(bottom))

            down_alpha = top[label_col].mean()
            up_alpha = bottom[label_col].mean()

            up_pre.append(up_precision)
            down_pre.append(down_precision)
            up_alpha_ll.append(up_alpha)
            down_alpha_ll.append(down_alpha)

        return (
            np.array(up_pre).mean(),
            np.array(down_pre).mean(),
            np.array(up_alpha_ll).mean(),
            np.array(down_alpha_ll).mean(),
        )

    def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
        """
        Test the signal in high frequency test set

        Predicts on the "test" segment, converts the label into alpha (label minus
        the mean over the first index level) and prints precision / average alpha
        for the lowest and highest `threhold` fraction of predictions.

        Raises
        ------
        ValueError
            if the model has not been trained
        """
        if self.model is None:
            raise ValueError("Model hasn't been trained yet")
        df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
        df_test.dropna(inplace=True)
        x_test = df_test["feature"]
        # Work on a copy so the alpha / pred columns are not written into a slice of `df_test`.
        y_test = df_test["label"].copy()
        # Convert label into alpha
        label_col = y_test.columns[0]
        y_test[label_col] = y_test[label_col] - y_test[label_col].groupby(level=0).transform("mean")

        res = pd.Series(self.model.predict(x_test.values), index=x_test.index)
        y_test["pred"] = res

        up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
        print("===============================")
        print("High frequency signal test")
        print("===============================")
        print("Test set precision: ")
        print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
        print("Test Alpha Average in test set: ")
        print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))

    def _prepare_data(self, dataset: DatasetH):
        """
        Build the LightGBM train / valid datasets.

        The single label column is converted into alpha (label minus the mean over
        the first index level) and then binarized: 0 when alpha < 0, otherwise 1.

        Raises
        ------
        ValueError
            if the train label is not a single column
        """
        df_train, df_valid = dataset.prepare(
            ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
        )

        y_train_raw = df_train["label"]
        if not (y_train_raw.values.ndim == 2 and y_train_raw.values.shape[1] == 1):
            raise ValueError("LightGBM doesn't support multi-label training")
        l_name = y_train_raw.columns[0]

        def _binary_alpha_label(df):
            # Convert label into alpha, then map it to {0, 1}. Computed on a local
            # Series, so the frames returned by the dataset are left untouched.
            label = df["label"][l_name]
            alpha = label - label.groupby(level=0).transform("mean")
            return np.where(alpha < 0, 0, 1)

        x_train, y_train = df_train["feature"], _binary_alpha_label(df_train)
        x_valid, y_valid = df_valid["feature"], _binary_alpha_label(df_valid)

        dtrain = lgb.Dataset(x_train.values, label=y_train)
        dvalid = lgb.Dataset(x_valid.values, label=y_valid)
        return dtrain, dvalid

    def fit(
        self,
        dataset: DatasetH,
        num_boost_round=1000,
        early_stopping_rounds=50,
        verbose_eval=20,
        evals_result=None,
        **kwargs
    ):
        """
        Train the model on the "train" segment, validating on "valid".

        Parameters
        ----------
        dataset : DatasetH
            dataset providing the "train" and "valid" segments
        num_boost_round : int
            maximum number of boosting rounds
        early_stopping_rounds : int
            forwarded to `lgb.train`
        verbose_eval : int
            forwarded to `lgb.train`
        evals_result : dict, optional
            filled with the evaluation history; a fresh dict is used when omitted
        **kwargs :
            forwarded to `lgb.train`
        """
        if evals_result is None:
            # A `dict()` default would be shared between every call of `fit`.
            evals_result = {}
        dtrain, dvalid = self._prepare_data(dataset)
        # NOTE(review): `early_stopping_rounds`, `verbose_eval` and `evals_result`
        # appear to have been replaced by callbacks in LightGBM 4.x — confirm the
        # pinned lightgbm version.
        self.model = lgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            valid_sets=[dtrain, dvalid],
            valid_names=["train", "valid"],
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=verbose_eval,
            evals_result=evals_result,
            **kwargs
        )
        # Keep only the first metric's history for each split.
        evals_result["train"] = list(evals_result["train"].values())[0]
        evals_result["valid"] = list(evals_result["valid"].values())[0]

    def predict(self, dataset, segment="test"):
        """
        Predict on one segment of the dataset.

        Parameters
        ----------
        dataset : DatasetH
            dataset to predict on
        segment : str or slice
            segment passed to `dataset.prepare`; defaults to "test"

        Returns
        -------
        pd.Series
            predictions indexed like the prepared features

        Raises
        ------
        ValueError
            if the model has not been fitted
        """
        if self.model is None:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
        return pd.Series(self.model.predict(x_test.values), index=x_test.index)

    def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
        """
        finetune model

        Parameters
        ----------
        dataset : DatasetH
            dataset for finetuning
        num_boost_round : int
            number of round to finetune model
        verbose_eval : int
            verbose level
        """
        # Based on existing model and finetune by train more rounds
        dtrain, _ = self._prepare_data(dataset)
        self.model = lgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            init_model=self.model,
            valid_sets=[dtrain],
            valid_names=["train"],
            verbose_eval=verbose_eval,
        )
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Text, Union
|
||||
from scipy.optimize import nnls
|
||||
from sklearn.linear_model import LinearRegression, Ridge, Lasso
|
||||
|
||||
@@ -84,8 +84,8 @@ class LinearModel(Model):
|
||||
self.coef_ = coef
|
||||
self.intercept_ = 0.0
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.coef_ is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(x_test.values @ self.coef_ + self.intercept_, index=x_test.index)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.ALSTM_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -264,11 +260,11 @@ class ALSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
dl_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test = dataset.prepare(segment, col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
dl_test.config(fillna_type="ffill+bfill")
|
||||
test_loader = DataLoader(dl_test, batch_size=self.batch_size, num_workers=self.n_jobs)
|
||||
self.ALSTM_model.eval()
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -83,7 +79,6 @@ class GATs(Model):
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -310,11 +305,11 @@ class GATs(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature")
|
||||
index = x_test.index
|
||||
self.GAT_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -273,11 +269,11 @@ class GRU(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.gru_model.eval()
|
||||
x_values = x_test.values
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -126,8 +121,8 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
self.logger.info("model:\n{:}".format(self.GRU_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GRU_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
|
||||
@@ -8,13 +8,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -268,11 +264,11 @@ class LSTM(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.lstm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -280,17 +276,13 @@ class LSTM(Model):
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
@@ -9,12 +9,7 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
|
||||
@@ -8,6 +8,7 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
import torch
|
||||
@@ -18,7 +19,7 @@ from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
|
||||
@@ -48,8 +49,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim,
|
||||
output_dim,
|
||||
input_dim=360,
|
||||
output_dim=1,
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
@@ -271,13 +272,12 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare("test", col_set="feature")
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
@@ -7,13 +7,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -442,11 +438,11 @@ class SFM(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.sfm_model.eval()
|
||||
x_values = x_test.values
|
||||
@@ -459,10 +455,7 @@ class SFM(Model):
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float()
|
||||
|
||||
if self.device != "cpu":
|
||||
x_batch = x_batch.to(self.device)
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
pred = self.sfm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
@@ -6,13 +6,9 @@ from __future__ import print_function
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
import copy
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
@@ -217,11 +213,11 @@ class TabnetModel(Model):
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
index = x_test.index
|
||||
self.tabnet_model.eval()
|
||||
x_values = torch.from_numpy(x_test.values)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import xgboost as xgb
|
||||
|
||||
from typing import Text, Union
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -57,8 +57,8 @@ class XGBModel(Model):
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
def predict(self, dataset):
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
|
||||
@@ -214,7 +214,7 @@ def cumulative_return_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
|
||||
|
||||
Graph desc:
|
||||
|
||||
@@ -94,7 +94,7 @@ def rank_label_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
""""""
|
||||
""" """
|
||||
|
||||
_name = None
|
||||
|
||||
|
||||
@@ -251,7 +251,7 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
|
||||
def generate_order_list(self, score_series, current, trade_exchange, pred_date, trade_date):
|
||||
"""
|
||||
Gnererate order list according to score_series at trade_date, will not change current.
|
||||
Generate order list according to score_series at trade_date, will not change current.
|
||||
|
||||
Parameters
|
||||
-----------
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from .record_temp import MultiSegRecord
|
||||
from .record_temp import SignalMseRecord
|
||||
|
||||
@@ -1,16 +1,60 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import logging
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from typing import Dict, Text, Any
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
    """
    This is the multiple segments signal record class that generates the signal prediction.
    This class inherits the ``RecordTemp`` class.

    For every requested segment it asks the model for predictions, fetches the
    matching labels from the dataset, and reports IC / Rank IC statistics.
    """

    def __init__(self, model, dataset, recorder=None):
        """
        Parameters
        ----------
        model :
            object exposing ``predict(dataset, segment)``.
        dataset : DatasetH
            dataset used for both prediction and label fetching.
        recorder :
            forwarded to ``RecordTemp.__init__``.

        Raises
        ------
        ValueError
            if `dataset` is not an instance of ``DatasetH``.
        """
        super().__init__(recorder=recorder)
        if not isinstance(dataset, qlib_dataset.DatasetH):
            # The previous wording ("is not DatasetH instead of ...") contradicted itself.
            raise ValueError("The type of dataset must be DatasetH instead of {:}".format(type(dataset)))
        self.model = model
        self.dataset = dataset

    def generate(self, segments: Dict[Text, Any], save: bool = False):
        """
        Evaluate the model on each segment and log the IC metrics.

        Parameters
        ----------
        segments : Dict[Text, Any]
            mapping from a display name to a segment accepted by
            ``model.predict`` and ``dataset.prepare``.
        save : bool
            when True, the per-segment result dict is stored through
            ``self.recorder.save_objects`` under ``results-<key>.pkl``.
        """
        for key, segment in segments.items():
            predics = self.model.predict(self.dataset, segment)
            if isinstance(predics, pd.Series):
                predics = predics.to_frame("score")
            labels = self.dataset.prepare(
                segments=segment, col_set="label", data_key=qlib_dataset.handler.DataHandlerLP.DK_R
            )
            # Compute the IC and Rank IC (only the first column of each frame is used)
            ic, ric = calc_ic(predics.iloc[:, 0], labels.iloc[:, 0])
            results = {"all-IC": ic, "mean-IC": ic.mean(), "all-Rank-IC": ric, "mean-Rank-IC": ric.mean()}
            logger.info("--- Results for {:} ({:}) ---".format(key, segment))
            ic_x100, ric_x100 = ic * 100, ric * 100
            logger.info("IC: {:.4f}%".format(ic_x100.mean()))
            # NOTE(review): mean/std is scale-invariant, so the x100 factor cancels here
            # and the trailing "%" is only cosmetic -- confirm whether it is intended.
            logger.info("ICIR: {:.4f}%".format(ic_x100.mean() / ic_x100.std()))
            logger.info("Rank IC: {:.4f}%".format(ric_x100.mean()))
            logger.info("Rank ICIR: {:.4f}%".format(ric_x100.mean() / ric_x100.std()))

            if save:
                save_name = "results-{:}.pkl".format(key)
                self.recorder.save_objects(**{save_name: results})
                logger.info(
                    "The record '{:}' has been saved as the artifact of the Experiment {:}".format(
                        save_name, self.recorder.experiment_id
                    )
                )
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
@@ -38,7 +82,7 @@ class SignalMseRecord(SignalRecord):
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
|
||||
@@ -1037,7 +1037,8 @@ class ClientProvider(BaseProvider):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
Inst.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
else:
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,22 +17,28 @@ class Dataset(Serializable):
|
||||
Preparing data for model training and inferencing.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, **kwargs):
|
||||
"""
|
||||
init is designed to finish following steps:
|
||||
|
||||
- init the sub instance and the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
- setup data
|
||||
- The data related attributes' names should start with '_' so that it will not be saved on disk when serializing.
|
||||
|
||||
- initialize the state of the dataset(info to prepare the data)
|
||||
- The name of essential state for preparing data should not start with '_' so that it could be serialized on disk when serializing.
|
||||
|
||||
The data could specify the info to caculate the essential data for preparation
|
||||
The data could specify the info to calculate the essential data for preparation
|
||||
"""
|
||||
self.setup_data(*args, **kwargs)
|
||||
self.setup_data(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
config is designed to configure the parameters that cannot be learned from the data
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
"""
|
||||
Setup the data.
|
||||
|
||||
@@ -39,7 +46,7 @@ class Dataset(Serializable):
|
||||
|
||||
- User have a Dataset object with learned status on disk.
|
||||
|
||||
- User load the Dataset object from the disk(Note the init function is skiped).
|
||||
- User load the Dataset object from the disk.
|
||||
|
||||
- User call `setup_data` to load new data.
|
||||
|
||||
@@ -47,7 +54,7 @@ class Dataset(Serializable):
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, *args, **kwargs) -> object:
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
The type of dataset depends on the model. (It could be pd.DataFrame, pytorch.DataLoader, etc.)
|
||||
The parameters should specify the scope for the prepared data
|
||||
@@ -76,44 +83,7 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHanlder, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
- arguments of DataHandler.init, such as 'enable_cache', etc.
|
||||
|
||||
segment_kwargs : dict
|
||||
Config of segments which is same as 'segments' in DatasetH.setup_data
|
||||
|
||||
"""
|
||||
if handler_kwargs:
|
||||
if not isinstance(handler_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(handler_kwargs)}")
|
||||
kwargs_init = {}
|
||||
kwargs_conf_data = {}
|
||||
conf_data_arg = {"instruments", "start_time", "end_time"}
|
||||
for k, v in handler_kwargs.items():
|
||||
if k in conf_data_arg:
|
||||
kwargs_conf_data.update({k: v})
|
||||
else:
|
||||
kwargs_init.update({k: v})
|
||||
|
||||
self.handler.conf_data(**kwargs_conf_data)
|
||||
self.handler.init(**kwargs_init)
|
||||
|
||||
if segment_kwargs:
|
||||
if not isinstance(segment_kwargs, dict):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -122,7 +92,7 @@ class DatasetH(Dataset):
|
||||
handler : Union[dict, DataHandler]
|
||||
handler could be:
|
||||
|
||||
- insntance of `DataHandler`
|
||||
- instance of `DataHandler`
|
||||
|
||||
- config of `DataHandler`. Please refer to `DataHandler`
|
||||
|
||||
@@ -142,8 +112,52 @@ class DatasetH(Dataset):
|
||||
'outsample': ("2017-01-01", "2020-08-01",),
|
||||
}
|
||||
"""
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Configure the DatasetH
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
Config of DataHandler, which could include the following arguments:
|
||||
|
||||
- arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'.
|
||||
|
||||
kwargs : dict
|
||||
Config of DatasetH, such as
|
||||
|
||||
- segments : dict
|
||||
Config of segments which is same as 'segments' in self.__init__
|
||||
|
||||
"""
|
||||
if handler_kwargs is not None:
|
||||
self.handler.config(**handler_kwargs)
|
||||
if "segments" in kwargs:
|
||||
self.segments = deepcopy(kwargs.pop("segments"))
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, handler_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
Setup the Data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
handler_kwargs : dict
|
||||
init arguments of DataHandler, which could include the following arguments:
|
||||
|
||||
- init_type : Init Type of Handler
|
||||
|
||||
- enable_cache : whether to enable cache
|
||||
|
||||
"""
|
||||
super().setup_data(**kwargs)
|
||||
if handler_kwargs is not None:
|
||||
self.handler.setup_data(**handler_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
@@ -158,7 +172,10 @@ class DatasetH(Dataset):
|
||||
----------
|
||||
slc : slice
|
||||
"""
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
else:
|
||||
return self.handler.fetch(slc, **kwargs)
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
@@ -186,6 +203,12 @@ class DatasetH(Dataset):
|
||||
The data to fetch: DK_*
|
||||
Default is DK_I, which indicates fetching data for **inference**.
|
||||
|
||||
kwargs :
|
||||
The parameters that kwargs may contain:
|
||||
flt_col : str
|
||||
It only exists in TSDatasetH, can be used to add a column of data(True or False) to filter data.
|
||||
This parameter is only supported when it is an instance of TSDatasetH.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
@@ -218,7 +241,7 @@ class TSDataSampler:
|
||||
(T)ime-(S)eries DataSampler
|
||||
This is the result of TSDatasetH
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
@@ -230,7 +253,9 @@ class TSDataSampler:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none"):
|
||||
def __init__(
|
||||
self, data: pd.DataFrame, start, end, step_len: int, fillna_type: str = "none", dtype=None, flt_data=None
|
||||
):
|
||||
"""
|
||||
Build a dataset which looks like torch.data.utils.Dataset.
|
||||
|
||||
@@ -252,6 +277,11 @@ class TSDataSampler:
|
||||
ffill with previous sample
|
||||
ffill+bfill:
|
||||
ffill with previous samples first and fill with later samples second
|
||||
flt_data : pd.Series
|
||||
a column of data(True or False) to filter data.
|
||||
None:
|
||||
keep all data
|
||||
|
||||
"""
|
||||
self.start = start
|
||||
self.end = end
|
||||
@@ -259,23 +289,51 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
|
||||
kwargs = {"object": self.data}
|
||||
if dtype is not None:
|
||||
kwargs["dtype"] = dtype
|
||||
|
||||
self.data_arr = np.array(**kwargs) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE:
|
||||
# - append last line with full NaN for better performance in `__getitem__`
|
||||
# - Keep the same dtype will result in a better performance
|
||||
self.data_arr = np.append(
|
||||
self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan, dtype=self.data_arr.dtype), axis=0
|
||||
)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
|
||||
@staticmethod
def flt_idx_map(flt_data, idx_map):
    """
    Re-number `idx_map` so that it only contains the rows kept by `flt_data`.

    Parameters
    ----------
    flt_data :
        iterable of per-row flags; a row is kept only when its flag equals True.
    idx_map : dict
        mapping from the original row position to its index entry.

    Returns
    -------
    dict
        mapping from the new consecutive position (0, 1, ...) to the entry that
        `idx_map` held for the corresponding kept row.
    """
    new_idx_map = {}
    for i, exist in enumerate(flt_data):
        # Compare against True instead of testing truthiness: the constructor builds
        # `data_index` with `flt_data == True`, and NaN flags (which are truthy)
        # must be dropped here as well so both structures stay aligned.
        if bool(exist == True):  # noqa: E712
            new_idx_map[len(new_idx_map)] = idx_map[i]
    return new_idx_map
|
||||
|
||||
def get_index(self):
|
||||
"""
|
||||
Get the pandas index of the data, it will be useful in following scenarios
|
||||
- Special sampler will be used (e.g. user want to sample day by day)
|
||||
"""
|
||||
return self.data.index[self.start_idx : self.end_idx]
|
||||
return self.data_index[self.start_idx : self.end_idx]
|
||||
|
||||
def config(self, **kwargs):
|
||||
# Config the attributes
|
||||
@@ -419,7 +477,7 @@ class TSDatasetH(DatasetH):
|
||||
(T)ime-(S)eries Dataset (H)andler
|
||||
|
||||
|
||||
Covnert the tabular data to Time-Series data
|
||||
Convert the tabular data to Time-Series data
|
||||
|
||||
Requirements analysis
|
||||
|
||||
@@ -433,18 +491,22 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, *args, **kwargs):
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(*args, **kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def setup_data(self, *args, **kwargs):
|
||||
super().setup_data(*args, **kwargs)
|
||||
def config(self, **kwargs):
|
||||
if "step_len" in kwargs:
|
||||
self.step_len = kwargs.pop("step_len")
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, **kwargs):
|
||||
super().setup_data(**kwargs)
|
||||
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
|
||||
cal = sorted(cal)
|
||||
# Get the datetime index for building timestamp
|
||||
self.cal = cal
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
|
||||
# Dataset decide how to slice data(Get more data for timeseries).
|
||||
start, end = slc.start, slc.stop
|
||||
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
|
||||
@@ -453,6 +515,25 @@ class TSDatasetH(DatasetH):
|
||||
|
||||
# TSDatasetH will retrieve more data for complete
|
||||
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
|
||||
return data
|
||||
|
||||
tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len)
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
    """
    Build the `TSDataSampler` for the segment `slc`.

    Fetching the raw frame is delegated to `_prepare_raw_seg`, which leaves a hook
    for preprocessing the data before the sampler is created.

    Parameters
    ----------
    slc : slice
        segment to prepare; `slc.start` / `slc.stop` become the sampler's range.
    kwargs :
        `dtype` and `flt_col` are consumed here; everything else is forwarded
        to `_prepare_raw_seg`.

        - dtype : passed to `TSDataSampler` (default None).
        - flt_col : when given, it is fetched as `col_set` and the resulting
          single column is passed to `TSDataSampler` as `flt_data`.
    """
    dtype = kwargs.pop("dtype", None)
    flt_col = kwargs.pop("flt_col", None)
    start, end = slc.start, slc.stop
    # TSDatasetH will retrieve more data for complete
    data = self._prepare_raw_seg(slc, **kwargs)

    flt_data = None
    if flt_col is not None:
        # Copy the kwargs only when a filter is requested, so the common no-filter
        # path does not pay for (or fail on) deep-copying the forwarded arguments.
        flt_kwargs = deepcopy(kwargs)
        flt_kwargs["col_set"] = flt_col
        flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
        assert len(flt_data.columns) == 1

    tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data)
    return tsds
|
||||
|
||||
@@ -6,7 +6,8 @@ import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Union, Tuple, List, Iterator, Optional
|
||||
from inspect import getfullargspec
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -16,7 +17,7 @@ from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import get_level_index, fetch_df_by_index
|
||||
from .utils import fetch_df_by_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
@@ -35,7 +36,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
Any order of the index level can be supported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -50,6 +51,9 @@ class DataHandler(Serializable):
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Tips for improving the performance of datahandler
|
||||
- Fetching data with `col_set=CS_RAW` will return the raw data and may avoid pandas from copying the data when calling `loc`
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -57,7 +61,7 @@ class DataHandler(Serializable):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
init_data=True,
|
||||
fetch_orig=True,
|
||||
):
|
||||
@@ -70,10 +74,10 @@ class DataHandler(Serializable):
|
||||
start_time of the original data.
|
||||
end_time :
|
||||
end_time of the original data.
|
||||
data_loader : Tuple[dict, str, DataLoader]
|
||||
data_loader : Union[dict, str, DataLoader]
|
||||
data loader to load the data.
|
||||
init_data :
|
||||
intialize the original data in the constructor.
|
||||
initialize the original data in the constructor.
|
||||
fetch_orig : bool
|
||||
Return the original data instead of copy if possible.
|
||||
"""
|
||||
@@ -99,10 +103,10 @@ class DataHandler(Serializable):
|
||||
self.fetch_orig = fetch_orig
|
||||
if init_data:
|
||||
with TimeInspector.logt("Init data"):
|
||||
self.init()
|
||||
self.setup_data()
|
||||
super().__init__()
|
||||
|
||||
def conf_data(self, **kwargs):
|
||||
def config(self, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
@@ -115,13 +119,16 @@ class DataHandler(Serializable):
|
||||
for k, v in kwargs.items():
|
||||
if k in attr_list:
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
raise KeyError("Such config is not supported.")
|
||||
|
||||
def init(self, enable_cache: bool = False):
|
||||
for attr in attr_list:
|
||||
if attr in kwargs:
|
||||
kwargs.pop(attr)
|
||||
|
||||
super().config(**kwargs)
|
||||
|
||||
def setup_data(self, enable_cache: bool = False):
|
||||
"""
|
||||
initialize the data.
|
||||
In case of running intialization for multiple time, it will do nothing for the second time.
|
||||
Set up the data in case of running initialization multiple times
|
||||
|
||||
It is responsible for maintaining following variable
|
||||
1) self._data
|
||||
@@ -159,6 +166,7 @@ class DataHandler(Serializable):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -181,6 +189,14 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
- A Dataset learned some processors to process data which is related to data segmentation
|
||||
- It will apply them every time when preparing data.
|
||||
- The learned processor require the dataframe remains the same format when fitting and applying
|
||||
- However the data format will change according to the parameters.
|
||||
- So the processors should be applied to the underlying data.
|
||||
|
||||
squeeze : bool
|
||||
whether squeeze columns and index
|
||||
@@ -189,8 +205,15 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
if proc_func is None:
|
||||
df = self._data
|
||||
else:
|
||||
# FIXME: fetching by time first will be more friendly to `proc_func`
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(self._data, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(self._data, col_set)
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
df = fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -257,6 +280,10 @@ class DataHandler(Serializable):
|
||||
class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
Tips for improving the performance of data handler
|
||||
- To reduce the memory cost
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
"""
|
||||
|
||||
# data key
|
||||
@@ -278,7 +305,7 @@ class DataHandlerLP(DataHandler):
|
||||
instruments=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
data_loader: Tuple[dict, str, DataLoader] = None,
|
||||
data_loader: Union[dict, str, DataLoader] = None,
|
||||
infer_processors=[],
|
||||
learn_processors=[],
|
||||
process_type=PTYPE_A,
|
||||
@@ -405,14 +432,28 @@ class DataHandlerLP(DataHandler):
|
||||
if self.drop_raw:
|
||||
del self._data
|
||||
|
||||
def config(self, processor_kwargs: dict = None, **kwargs):
|
||||
"""
|
||||
configuration of data.
|
||||
# what data to be loaded from data source
|
||||
|
||||
This method will be used when loading pickled handler from dataset.
|
||||
The data will be initialized with different time range.
|
||||
|
||||
"""
|
||||
super().config(**kwargs)
|
||||
if processor_kwargs is not None:
|
||||
for processor in self.get_all_processors():
|
||||
processor.config(**processor_kwargs)
|
||||
|
||||
# init type
|
||||
IT_FIT_SEQ = "fit_seq" # the input of `fit` will be the output of the previous processor
|
||||
IT_FIT_IND = "fit_ind" # the input of `fit` will be the original df
|
||||
IT_LS = "load_state" # The state of the object has been load by pickle
|
||||
|
||||
def init(self, init_type: str = IT_FIT_SEQ, enable_cache: bool = False):
|
||||
def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs):
|
||||
"""
|
||||
Initialize the data of Qlib
|
||||
Set up the data in case of running initialization multiple times
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -427,7 +468,7 @@ class DataHandlerLP(DataHandler):
|
||||
when we call `init` next time
|
||||
"""
|
||||
# init raw data
|
||||
super().init(enable_cache=enable_cache)
|
||||
super().setup_data(**kwargs)
|
||||
|
||||
with TimeInspector.logt("fit & process data"):
|
||||
if init_type == DataHandlerLP.IT_FIT_IND:
|
||||
@@ -456,6 +497,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from underlying data source
|
||||
@@ -470,12 +512,18 @@ class DataHandlerLP(DataHandler):
|
||||
select a set of meaningful columns.(e.g. features, columns).
|
||||
data_key : str
|
||||
the data to fetch: DK_*.
|
||||
proc_func: Callable
|
||||
please refer to the doc of DataHandler.fetch
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
df = self._get_df_by_key(data_key)
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy in case of `proc_func` changing the data inplace....
|
||||
df = proc_func(fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
df = self._fetch_df_by_col(df, col_set)
|
||||
return fetch_df_by_index(df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
@@ -13,6 +13,7 @@ from qlib.data import D
|
||||
from qlib.data import filter as filter_module
|
||||
from qlib.data.filter import BaseDFilter
|
||||
from qlib.utils import load_dataset, init_instance_by_config
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class DataLoader(abc.ABC):
|
||||
@@ -217,3 +218,68 @@ class StaticDataLoader(DataLoader):
|
||||
join=self.join,
|
||||
)
|
||||
self._data.sort_index(inplace=True)
|
||||
|
||||
|
||||
class DataLoaderDH(DataLoader):
    """DataLoaderDH

    DataLoader based on (D)ata (H)andler.
    It is designed to load multiple data from data handlers.

    - If you just want to load data from a single data handler, you can write them in a single data handler.

    TODO: What makes this module not that easy to use.

    - For online scenario

        - The underlying data handler should be configured. But the data loader doesn't provide such an interface & hook.
    """

    def __init__(self, handler_config: dict, fetch_kwargs: dict = None, is_group=False):
        """
        Parameters
        ----------
        handler_config : dict
            handler_config will be used to describe the handlers

            .. code-block::

                <handler_config> := {
                    "group_name1": <handler>
                    "group_name2": <handler>
                }
                or
                <handler_config> := <handler>
                <handler> := DataHandler Instance | DataHandler Config

        fetch_kwargs : dict, optional
            fetch_kwargs will be used to describe the different arguments of fetch method, such as col_set, squeeze, data_key, etc.
            It is merged on top of the default ``{"col_set": DataHandler.CS_RAW}``.
            ``None`` (the default) means "no extra arguments".

        is_group: bool
            is_group will be used to describe whether the key of handler_config is group

        """
        # Imported inside the method rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from qlib.data.dataset.handler import DataHandler

        if is_group:
            self.handlers = {
                grp: init_instance_by_config(config, accept_types=DataHandler) for grp, config in handler_config.items()
            }
        else:
            self.handlers = init_instance_by_config(handler_config, accept_types=DataHandler)

        self.is_group = is_group
        # Start from the default and let the caller override/extend it. The default
        # argument is `None` instead of a shared mutable `{}` literal.
        self.fetch_kwargs = {"col_set": DataHandler.CS_RAW}
        if fetch_kwargs is not None:
            self.fetch_kwargs.update(fetch_kwargs)

    def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
        """
        Fetch data from the configured handler(s) for the given time range.

        Parameters
        ----------
        instruments :
            Not used by this loader; a warning is logged when it is not None.
        start_time :
            Start of the ``slice`` passed as the handler's ``selector`` on the "datetime" level.
        end_time :
            End of the ``slice`` passed as the handler's ``selector`` on the "datetime" level.

        Returns
        -------
        pd.DataFrame:
            When ``is_group`` is True, the per-group fetch results concatenated along
            ``axis=1`` and keyed by group name; otherwise the result of the single
            handler's ``fetch``.
        """
        if instruments is not None:
            get_module_logger(self.__class__.__name__).warning(f"instruments[{instruments}] is ignored")

        if self.is_group:
            df = pd.concat(
                {
                    grp: dh.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
                    for grp, dh in self.handlers.items()
                },
                axis=1,
            )
        else:
            df = self.handlers.fetch(selector=slice(start_time, end_time), level="datetime", **self.fetch_kwargs)
        return df
|
||||
|
||||
18
qlib/data/dataset/processor.py
Executable file → Normal file
18
qlib/data/dataset/processor.py
Executable file → Normal file
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
from typing import Union, Text
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -14,7 +15,7 @@ from ...utils.paral import datetime_groupby_apply
|
||||
EPS = 1e-12
|
||||
|
||||
|
||||
def get_group_columns(df: pd.DataFrame, group: str):
|
||||
def get_group_columns(df: pd.DataFrame, group: Union[Text, None]):
|
||||
"""
|
||||
get a group of columns from multi-index columns DataFrame
|
||||
|
||||
@@ -72,6 +73,17 @@ class Processor(Serializable):
|
||||
"""
|
||||
return True
|
||||
|
||||
def config(self, **kwargs):
    """
    Configure the processor.

    ``fit_start_time`` and ``fit_end_time`` are handled here: each one is assigned
    to the instance only when the instance already has an attribute of that name,
    and both are always removed from ``kwargs``. Everything that remains is
    forwarded to the parent class's ``config``.
    """
    fit_attrs = {"fit_start_time", "fit_end_time"}
    # Snapshot the matching keys first so that `kwargs` can be popped while looping.
    for name in [key for key in kwargs if key in fit_attrs]:
        value = kwargs.pop(name)
        if hasattr(self, name):
            setattr(self, name, value)
    super().config(**kwargs)
|
||||
|
||||
|
||||
class DropnaProcessor(Processor):
|
||||
def __init__(self, fields_group=None):
|
||||
@@ -118,7 +130,7 @@ class FilterCol(Processor):
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
def __call__(self, df):
|
||||
def tanh_denoise(data):
|
||||
@@ -133,7 +145,7 @@ class TanhProcess(Processor):
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
"""Process infinity """
|
||||
"""Process infinity"""
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
|
||||
118
qlib/log.py
118
qlib/log.py
@@ -12,7 +12,41 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
class MetaLogger(type):
    """
    Metaclass that copies the members of ``logging.Logger`` into the class being created.

    Every entry of ``logging.Logger.__dict__`` that the new class body does not define
    itself is added to the class namespace. ``__reduce__`` is never copied
    (presumably so the created class keeps default pickling instead of Logger's -- TODO confirm).
    """

    def __new__(cls, name, bases, dict):
        for attr, member in logging.Logger.__dict__.copy().items():
            # Members defined in the class body win over Logger's; `__reduce__` is skipped.
            if attr == "__reduce__" or attr in dict:
                continue
            dict[attr] = member
        return type.__new__(cls, name, bases, dict)
|
||||
|
||||
|
||||
class QlibLogger(metaclass=MetaLogger):
    """
    Customized logger for Qlib.

    The instance stores only the module name and a level. The real
    ``logging.Logger`` is looked up by name on every access of :attr:`logger`,
    and attributes not found on the instance are forwarded to it.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.level = 0

    @property
    def logger(self):
        """Return ``logging.getLogger(self.module_name)`` with the stored level applied."""
        underlying = logging.getLogger(self.module_name)
        underlying.setLevel(self.level)
        return underlying

    def setLevel(self, level):
        """Store ``level``; it is applied to the underlying logger on the next :attr:`logger` access."""
        self.level = level

    def __getattr__(self, name):
        # During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
        if name == "__setstate__":
            raise AttributeError
        # Only reached when normal attribute lookup failed: forward to the real logger.
        return self.logger.__getattribute__(name)
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
@@ -27,7 +61,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
# Get logger.
|
||||
module_logger = logging.getLogger(module_name)
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
return module_logger
|
||||
|
||||
@@ -129,3 +163,83 @@ class LogFilter(logging.Filter):
|
||||
elif isinstance(self.param, list):
|
||||
allow = not any([self.match_msg(p, record.msg) for p in self.param])
|
||||
return allow
|
||||
|
||||
|
||||
def set_global_logger_level(level: int, return_orig_handler_level: bool = False):
    """set qlib.xxx logger handlers level

    The level is assigned to every handler attached to the logger registered under
    the name ``"qlib"``. Nothing is changed when no such logger exists yet.

    Parameters
    ----------
    level: int
        logger level

    return_orig_handler_level: bool
        return origin handler level map

    Returns
    -------
    dict or None
        ``{handler: original_level}`` when ``return_orig_handler_level`` is True
        (empty when there was nothing to change); otherwise ``None``.

    Examples
    ---------

        .. code-block:: python

            import qlib
            import logging
            from qlib.log import get_module_logger, set_global_logger_level
            qlib.init()

            tmp_logger_01 = get_module_logger("tmp_logger_01", level=logging.INFO)
            tmp_logger_01.info("1. tmp_logger_01 info show")

            global_level = logging.WARNING + 1
            set_global_logger_level(global_level)
            tmp_logger_02 = get_module_logger("tmp_logger_02", level=logging.INFO)
            tmp_logger_02.log(msg="2. tmp_logger_02 log show", level=global_level)

            tmp_logger_01.info("3. tmp_logger_01 info do not show")

    """
    _handler_level_map = {}
    qlib_logger = logging.root.manager.loggerDict.get("qlib", None)
    # `loggerDict` stores a `logging.PlaceHolder` (which has no `handlers`) for "qlib"
    # when only child loggers such as "qlib.xxx" have been created so far, so an
    # `is not None` check is not sufficient here.
    if isinstance(qlib_logger, logging.Logger):
        for _handler in qlib_logger.handlers:
            _handler_level_map[_handler] = _handler.level
            _handler.level = level
    return _handler_level_map if return_orig_handler_level else None
|
||||
|
||||
|
||||
@contextmanager
def set_global_logger_level_cm(level: int):
    """Context-manager form of ``set_global_logger_level``

    The handler levels are changed on entry and the levels recorded at that
    moment are written back on exit, also when the body raises.

    Parameters
    ----------
    level: int
        logger level

    Examples
    ---------

        .. code-block:: python

            import qlib
            import logging
            from qlib.log import get_module_logger, set_global_logger_level_cm
            qlib.init()

            tmp_logger_01 = get_module_logger("tmp_logger_01", level=logging.INFO)
            tmp_logger_01.info("1. tmp_logger_01 info show")

            global_level = logging.WARNING + 1
            with set_global_logger_level_cm(global_level):
                tmp_logger_02 = get_module_logger("tmp_logger_02", level=logging.INFO)
                tmp_logger_02.log(msg="2. tmp_logger_02 log show", level=global_level)
                tmp_logger_01.info("3. tmp_logger_01 info do not show")

            tmp_logger_01.info("4. tmp_logger_01 info show")

    """
    saved_levels = set_global_logger_level(level, return_orig_handler_level=True)
    try:
        yield
    finally:
        # Put every handler back to the level it had before entering the block.
        for handler, original_level in saved_levels.items():
            handler.level = original_level
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import abc
|
||||
from typing import Text, Union
|
||||
from ..utils.serial import Serializable
|
||||
from ..data.dataset import Dataset
|
||||
|
||||
@@ -10,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
"""Make predictions after modeling things"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" leverage Python syntactic sugar to make the models' behaviors like functions """
|
||||
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
@@ -59,7 +60,7 @@ class Model(BaseModel):
|
||||
raise NotImplementedError()
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, dataset: Dataset) -> object:
|
||||
def predict(self, dataset: Dataset, segment: Union[Text, slice] = "test") -> object:
|
||||
"""give prediction given Dataset
|
||||
|
||||
Parameters
|
||||
@@ -67,6 +68,9 @@ class Model(BaseModel):
|
||||
dataset : Dataset
|
||||
dataset will generate the processed dataset from model training.
|
||||
|
||||
segment : Text or slice
|
||||
dataset will use this segment to prepare data. (default=test)
|
||||
|
||||
Returns
|
||||
-------
|
||||
Prediction results with certain type such as `pandas.Series`.
|
||||
|
||||
0
qlib/model/ens/__init__.py
Normal file
0
qlib/model/ens/__init__.py
Normal file
115
qlib/model/ens/ensemble.py
Normal file
115
qlib/model/ens/ensemble.py
Normal file
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Ensemble module can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them into an ensemble prediction.
|
||||
"""
|
||||
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.utils import FLATTEN_TUPLE, flatten_dict
|
||||
|
||||
|
||||
class Ensemble:
    """Base class for callables that merge an ensemble dict into one ensemble object.

    For example: {Rollinga_b: object, Rollingb_c: object} -> object

    Subclasses implement ``__call__``; the base implementation only raises.

    When calling this class:

        Args:
            ensemble_dict (dict): the ensemble dict like {name: things} waiting for merging

        Returns:
            object: the ensemble object
    """

    def __call__(self, ensemble_dict: dict, *args, **kwargs):
        message = "Please implement the `__call__` method."
        raise NotImplementedError(message)
|
||||
|
||||
|
||||
class SingleKeyEnsemble(Ensemble):
    """
    Unwrap dicts that hold exactly one entry, to make the result more readable.

    {Only key: Only value} -> Only value

    A dict with zero or with several keys is returned as a dict. With
    ``recursion=True`` the values are processed first, so nested single-key
    dicts are unwrapped as well.

    NOTE: Default runs recursively.

    When calling this class:

        Args:
            ensemble_dict (dict): the dict. The key of the dict will be ignored.
                A non-dict argument is returned unchanged.
            recursion (bool): whether to apply the same rule to the values first.

        Returns:
            dict: the readable dict.
    """

    def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object:
        if not isinstance(ensemble_dict, dict):
            return ensemble_dict
        if recursion:
            # Simplify the children before looking at this level.
            ensemble_dict = {key: self(value, recursion) for key, value in ensemble_dict.items()}
        if len(ensemble_dict) == 1:
            return next(iter(ensemble_dict.values()))
        return ensemble_dict
|
||||
|
||||
|
||||
class RollingEnsemble(Ensemble):
    """Merge a dict of rolling dataframes like `prediction` or `IC` into an ensemble.

    NOTE: The values of dict must be pd.DataFrame, and have the index "datetime".

    When calling this class:

        Args:
            ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                The key of the dict will be ignored.

        Returns:
            pd.DataFrame: the complete result of rolling.
    """

    def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
        # Order the pieces by the earliest "datetime" each one contains.
        ordered = sorted(ensemble_dict.values(), key=lambda x: x.index.get_level_values("datetime").min())
        merged = pd.concat(ordered)
        # If there are duplicated predictions, use the latest prediction
        deduplicated = merged[~merged.index.duplicated(keep="last")]
        return deduplicated.sort_index()
|
||||
|
||||
|
||||
class AverageEnsemble(Ensemble):
    """
    Average and standardize a dict of same shape dataframe like `prediction` or `IC` into an ensemble.

    NOTE: The values of dict must be pd.DataFrame, and have the index "datetime". If it is a nested dict, then flat it.

    When calling this class:

        Args:
            ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}.
                The key of the dict will be ignored.

        Returns:
            pd.DataFrame: the complete result of averaging and standardizing.
            NOTE(review): the last reduction is ``mean(axis=1)``, which normally yields
            a ``pd.Series`` -- confirm whether the annotation should be updated.
    """

    def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
        # need to flatten the nested dict
        ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
        values = list(ensemble_dict.values())
        results = pd.concat(values, axis=1)
        # Standardize every column within each "datetime" group. `group_keys=False`
        # keeps the original index: without it, pandas >= 2.0 prepends the group key
        # to the index of the applied result.
        results = results.groupby("datetime", group_keys=False).apply(lambda df: (df - df.mean()) / df.std())
        # Average the standardized columns row by row.
        results = results.mean(axis=1)
        results = results.sort_index()
        return results
|
||||
113
qlib/model/ens/group.py
Normal file
113
qlib/model/ens/group.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Group can group a set of objects based on `group_func` and change them to a dict.
|
||||
After group, we provide a method to reduce them.
|
||||
|
||||
For example:
|
||||
|
||||
group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}
|
||||
reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object}
|
||||
|
||||
"""
|
||||
|
||||
from qlib.model.ens.ensemble import Ensemble, RollingEnsemble
|
||||
from typing import Callable, Union
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
class Group:
    """Group the objects based on dict"""

    def __init__(self, group_func=None, ens: Ensemble = None):
        """
        Init Group.

        Args:
            group_func (Callable, optional): Given a dict and return the group key and one of the group elements.

                For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}

                Defaults to None.

            ens (Ensemble, optional): If not None, do ensemble for grouped value after grouping.
        """
        self._group_func = group_func
        self._ens_func = ens

    def group(self, *args, **kwargs) -> dict:
        """
        Group a set of objects and change them to a dict.

        For example: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}}

        Returns:
            dict: grouped dict

        Raises:
            NotImplementedError: if no callable `group_func` was configured.
        """
        group_func = getattr(self, "_group_func", None)
        if not isinstance(group_func, Callable):
            raise NotImplementedError("Please specify valid `group_func`.")
        return self._group_func(*args, **kwargs)

    def reduce(self, *args, **kwargs) -> dict:
        """
        Reduce grouped dict.

        For example: {(A,B): {C1: object, C2: object}} -> {(A,B): object}

        Returns:
            dict: reduced dict

        Raises:
            NotImplementedError: if no callable ensemble was configured.
        """
        ens_func = getattr(self, "_ens_func", None)
        if not isinstance(ens_func, Callable):
            raise NotImplementedError("Please specify valid `_ens_func`.")
        return self._ens_func(*args, **kwargs)

    def __call__(self, ungrouped_dict: dict, n_jobs: int = 1, verbose: int = 0, *args, **kwargs) -> dict:
        """
        Group the ungrouped_dict into different groups.

        Args:
            ungrouped_dict (dict): the ungrouped dict waiting for grouping like {name: things}
            n_jobs: how many processes you need.
            verbose: the print mode for Parallel.

        Returns:
            dict: grouped_dict like {G1: object, G2: object}
        """

        # NOTE: The multiprocessing will raise error if you use `Serializable`
        # Because the `Serializable` will affect the behaviors of pickle
        grouped_dict = self.group(ungrouped_dict, *args, **kwargs)

        # One reduce job per group; results are matched back to keys by position.
        jobs = [delayed(Group.reduce)(self, value) for value in grouped_dict.values()]
        reduced = Parallel(n_jobs=n_jobs, verbose=verbose)(jobs)
        return dict(zip(grouped_dict.keys(), reduced))
|
||||
|
||||
|
||||
class RollingGroup(Group):
    """Group the rolling dict"""

    def __init__(self):
        super().__init__(ens=RollingEnsemble())

    def group(self, rolling_dict: dict) -> dict:
        """Given a rolling dict like {(A,B,R): things}, return the grouped dict like {(A,B): {R:things}}

        NOTE: There is an assumption which is the rolling key is at the end of the key tuple, because the rolling results always need to be ensembled first.

        Args:
            rolling_dict (dict): a rolling dict. Entries whose key is not a tuple are left out of the result.

        Returns:
            dict: grouped dict
        """
        grouped_dict = {}
        for key, values in rolling_dict.items():
            if not isinstance(key, tuple):
                continue
            # Everything but the last element identifies the group; the last one is the rolling key.
            bucket = grouped_dict.setdefault(key[:-1], {})
            bucket[key[-1]] = values
        return grouped_dict
|
||||
@@ -1,27 +0,0 @@
|
||||
import abc
|
||||
import typing
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
    """Abstract interface of a task generator: a callable that returns a list of task dicts."""

    @abc.abstractmethod
    def __call__(self, *args, **kwargs) -> typing.List[dict]:
        """
        Generate tasks.

        Parameters
        ----------
        args, kwargs:
            The info for generating tasks

            Example 1):
                input: a specific task template
                output: rolling version of the tasks

            Example 2):
                input: a specific task template
                output: a set of tasks with different losses

        Returns
        -------
        typing.List[dict]:
            A list of tasks
        """
|
||||
@@ -1,42 +1,446 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
"""
|
||||
The Trainer will train a list of tasks and return a list of model recorders.
|
||||
There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder).
|
||||
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offers two kinds of Trainer: ``TrainerR`` is the simplest way, and ``TrainerRM`` is based on TaskManager to help manage the task lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name):
|
||||
def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder:
    """
    Begin task training to start a recorder and save the task config.

    No model is fitted here: the function only logs the flattened config as
    params, stores the config as the object "task", and tags the recorder with
    the current hostname.

    Args:
        task_config (dict): the config of a task
        experiment_name (str): the name of experiment
        recorder_name (str): the given name will be the recorder name. None for using rid.

    Returns:
        Recorder: the model recorder
    """
    with R.start(experiment_name=experiment_name, recorder_name=recorder_name):
        flat_params = flatten_dict(task_config)
        R.log_params(**flat_params)
        R.save_objects(task=task_config)  # keep the original format and datatype
        R.set_tags(hostname=socket.gethostname())
        recorder: Recorder = R.get_recorder()
    return recorder
|
||||
|
||||
|
||||
def end_task_train(rec: Recorder, experiment_name: str) -> Recorder:
    """
    Finish task training with real model fitting and saving.

    The recorder is resumed, the task config saved under "task" is loaded from
    it, and the model and dataset are built from that config. The fitted model
    is saved as "params.pkl" and the dataset as "dataset"; afterwards the
    records listed under the task's "record" key are generated.

    Args:
        rec (Recorder): the recorder will be resumed
        experiment_name (str): the name of experiment

    Returns:
        Recorder: the model recorder
    """
    with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True):
        task_config = R.load_object("task")

        # model & dataset initiation
        model: Model = init_instance_by_config(task_config["model"])
        dataset: Dataset = init_instance_by_config(task_config["dataset"])

        # model training
        model.fit(dataset)
        R.save_objects(**{"params.pkl": model})

        # this dataset is saved for online inference. So the concrete data should not be dumped
        dataset.config(dump_all=False, recursive=True)
        R.save_objects(**{"dataset": dataset})

        # generate records: prediction, backtest, and analysis
        record_confs = task_config.get("record", [])
        if isinstance(record_confs, dict):  # prevent only one dict
            record_confs = [record_confs]
        for record_conf in record_confs:
            cls, kwargs = get_cls_kwargs(record_conf, default_module="qlib.workflow.record_temp")
            # SignalRecord additionally receives the model and the dataset.
            rconf = (
                {"model": model, "dataset": dataset, "recorder": rec} if cls is SignalRecord else {"recorder": rec}
            )
            cls(**kwargs, **rconf).generate()

    return rec
|
||||
|
||||
|
||||
def task_train(task_config: dict, experiment_name: str) -> Recorder:
    """
    Task based training, will be divided into two steps:
    ``begin_task_train`` followed by ``end_task_train``.

    Parameters
    ----------
    task_config : dict
        The config of a task.
    experiment_name: str
        The name of experiment

    Returns
    ----------
    Recorder: The instance of the recorder
    """
    started = begin_task_train(task_config, experiment_name)
    return end_task_train(started, experiment_name)
|
||||
|
||||
|
||||
class Trainer:
|
||||
"""
|
||||
The trainer can train a list of models.
|
||||
There are Trainer and DelayTrainer, which can be distinguished by when it will finish real training.
|
||||
"""
|
||||
|
||||
# model initiation
|
||||
model = init_instance_by_config(task_config["model"])
|
||||
dataset = init_instance_by_config(task_config["dataset"])
|
||||
def __init__(self):
|
||||
self.delay = False
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name=experiment_name):
|
||||
# train model
|
||||
R.log_params(**flatten_dict(task_config))
|
||||
model.fit(dataset)
|
||||
recorder = R.get_recorder()
|
||||
R.save_objects(**{"params.pkl": model})
|
||||
def train(self, tasks: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of task definitions, begin training, and return the models.
|
||||
|
||||
# generate records: prediction, backtest, and analysis
|
||||
for record in task_config["record"]:
|
||||
if record["class"] == SignalRecord.__name__:
|
||||
srconf = {"model": model, "dataset": dataset, "recorder": recorder}
|
||||
record["kwargs"].update(srconf)
|
||||
sr = init_instance_by_config(record)
|
||||
sr.generate()
|
||||
else:
|
||||
rconf = {"recorder": recorder}
|
||||
record["kwargs"].update(rconf)
|
||||
ar = init_instance_by_config(record)
|
||||
ar.generate()
|
||||
For Trainer, it finishes real training in this method.
|
||||
For DelayTrainer, it only does some preparation in this method.
|
||||
|
||||
Args:
|
||||
tasks: a list of tasks
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `train` method.")
|
||||
|
||||
def end_train(self, models: list, *args, **kwargs) -> list:
|
||||
"""
|
||||
Given a list of models, finished something at the end of training if you need.
|
||||
The models may be Recorder, txt file, database, and so on.
|
||||
|
||||
For Trainer, it does some finishing touches in this method.
|
||||
For DelayTrainer, it finishes real training in this method.
|
||||
|
||||
Args:
|
||||
models: a list of models
|
||||
|
||||
Returns:
|
||||
list: a list of models
|
||||
"""
|
||||
# do nothing if you finished all work in `train` method
|
||||
return models
|
||||
|
||||
def is_delay(self) -> bool:
|
||||
"""
|
||||
If Trainer will delay finishing `end_train`.
|
||||
|
||||
Returns:
|
||||
bool: if DelayTrainer
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
|
||||
class TrainerR(Trainer):
    """
    Trainer based on (R)ecorder.
    It will train a list of tasks and return a list of model recorders in a linear way.

    Assumption: models were defined by `task` and the results will be saved to `Recorder`.
    """

    # These tags help you distinguish whether the Recorder has finished training
    STATUS_KEY = "train_status"
    STATUS_BEGIN = "begin_task_train"
    STATUS_END = "end_task_train"

    def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
        """
        Init TrainerR.

        Args:
            experiment_name (str, optional): the default name of experiment.
            train_func (Callable, optional): default training method. Defaults to `task_train`.
        """
        super().__init__()
        self.experiment_name = experiment_name
        self.train_func = train_func

    def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
        """
        Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed.

        Every returned recorder is tagged with ``STATUS_BEGIN``.

        Args:
            tasks (list): a list of definitions based on `task` dict
            train_func (Callable): the training method which needs at least `tasks` and `experiment_name`. None for the default training method.
            experiment_name (str): the experiment name, None for use default name.
            kwargs: the params for train_func.

        Returns:
            List[Recorder]: a list of Recorders
        """
        if len(tasks) == 0:
            return []
        # Fall back to the values given at construction time.
        train_func = self.train_func if train_func is None else train_func
        experiment_name = self.experiment_name if experiment_name is None else experiment_name

        trained = []
        for task in tasks:
            recorder = train_func(task, experiment_name, **kwargs)
            recorder.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
            trained.append(recorder)
        return trained

    def end_train(self, recs: list, **kwargs) -> List[Recorder]:
        """
        Set STATUS_END tag to the recorders.

        Args:
            recs (list): a list of trained recorders.

        Returns:
            List[Recorder]: the same list as the param.
        """
        end_tag = {self.STATUS_KEY: self.STATUS_END}
        for recorder in recs:
            recorder.set_tags(**end_tag)
        return recs
|
||||
|
||||
|
||||
class DelayTrainerR(TrainerR):
    """
    A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
    """

    def __init__(self, experiment_name: str = None, train_func=begin_task_train, end_train_func=end_task_train):
        """
        Init DelayTrainerR.

        Args:
            experiment_name (str): the default name of experiment.
            train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
            end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
        """
        super().__init__(experiment_name, train_func)
        self.end_train_func = end_train_func
        self.delay = True

    def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
        """
        Given a list of Recorder and return a list of trained Recorder.
        This class will finish real data loading and model fitting.

        Recorders already tagged with ``STATUS_END`` are skipped. A recorder that
        carries no status tag at all is treated as not finished and is trained.

        Args:
            recs (list): a list of Recorder, the tasks have been saved to them
            end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
            experiment_name (str): the experiment name, None for use default name.
            kwargs: the params for end_train_func.

        Returns:
            List[Recorder]: a list of Recorders
        """
        if end_train_func is None:
            end_train_func = self.end_train_func
        if experiment_name is None:
            experiment_name = self.experiment_name
        for rec in recs:
            # A missing status tag must not abort the whole batch with a KeyError.
            try:
                status = rec.list_tags()[self.STATUS_KEY]
            except KeyError:
                status = None
            if status == self.STATUS_END:
                continue
            end_train_func(rec, experiment_name, **kwargs)
            rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
        return recs
|
||||
|
||||
|
||||
class TrainerRM(Trainer):
|
||||
"""
|
||||
Trainer based on (R)ecorder and Task(M)anager.
|
||||
It can train a list of tasks and return a list of model recorders in a multiprocessing way.
|
||||
|
||||
Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager
|
||||
"""
|
||||
|
||||
# These tags help you distinguish whether the Recorder has finished training
|
||||
STATUS_KEY = "train_status"
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
|
||||
def train(
|
||||
self,
|
||||
tasks: list,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
before_status: str = TaskManager.STATUS_WAITING,
|
||||
after_status: str = TaskManager.STATUS_DONE,
|
||||
**kwargs,
|
||||
) -> List[Recorder]:
|
||||
"""
|
||||
Train a list of `task`s and return a list of trained Recorders. The order can be guaranteed.
|
||||
|
||||
This method defaults to a single process, but TaskManager offers a great way to parallelize training.
|
||||
Users can customize their train_func to realize multiple processes or even multiple machines.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definitions based on `task` dict
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
|
||||
after_status (str): the status that tasks are set to after being trained. Can be STATUS_DONE (the default) or STATUS_PART_DONE.
|
||||
kwargs: the params for train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
def end_train(self, recs: list, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Set STATUS_END tag to the recorders.
|
||||
|
||||
Args:
|
||||
recs (list): a list of trained recorders.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
A delayed implementation based on TrainerRM, which means `train` method may only do some preparation and `end_train` method can do the real model fitting.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
experiment_name: str = None,
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
|
||||
Args:
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE.
|
||||
|
||||
Args:
|
||||
tasks (list): a list of definition based on `task` dict
|
||||
train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
tasks,
|
||||
train_func=train_func,
|
||||
experiment_name=experiment_name,
|
||||
after_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def end_train(self, recs, end_train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
Given a list of Recorders, return a list of trained Recorders.
|
||||
This method will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
kwargs: the params for end_train_func.
|
||||
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
@@ -5,9 +5,9 @@ import abc
|
||||
|
||||
|
||||
class BaseOptimizer(abc.ABC):
|
||||
""" Construct portfolio with a optimization related method """
|
||||
"""Construct portfolio with a optimization related method"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" Generate a optimized portfolio allocation """
|
||||
"""Generate a optimized portfolio allocation"""
|
||||
pass
|
||||
|
||||
@@ -6,6 +6,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import re
|
||||
import copy
|
||||
import json
|
||||
@@ -24,7 +25,9 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple, Text, Optional
|
||||
from typing import Union, Tuple, Any, Text, Optional
|
||||
from types import ModuleType
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -165,24 +168,25 @@ def parse_field(field):
|
||||
return re.sub(r"\$(\w+)", r'Feature("\1")', re.sub(r"(\w+\s*)\(", r"Operators.\1(", field))
|
||||
|
||||
|
||||
def get_module_by_module_path(module_path):
|
||||
def get_module_by_module_path(module_path: Union[str, ModuleType]):
|
||||
"""Load module path
|
||||
|
||||
:param module_path:
|
||||
:return:
|
||||
"""
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
if isinstance(module_path, ModuleType):
|
||||
module = module_path
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
|
||||
if module_path.endswith(".py"):
|
||||
module_spec = importlib.util.spec_from_file_location("", module_path)
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
module_spec.loader.exec_module(module)
|
||||
else:
|
||||
module = importlib.import_module(module_path)
|
||||
return module
|
||||
|
||||
|
||||
def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
def get_cls_kwargs(config: Union[dict, str], default_module: Union[str, ModuleType] = None) -> (type, dict):
|
||||
"""
|
||||
extract class and kwargs from config info
|
||||
|
||||
@@ -191,8 +195,10 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
config : [dict, str]
|
||||
similar to config
|
||||
|
||||
module : Python module
|
||||
default_module : Python module or str
|
||||
It should be a python module to load the class type
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -200,10 +206,14 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
the class object and it's arguments.
|
||||
"""
|
||||
if isinstance(config, dict):
|
||||
module = get_module_by_module_path(config.get("module_path", default_module))
|
||||
|
||||
# raise AttributeError
|
||||
klass = getattr(module, config["class"])
|
||||
kwargs = config.get("kwargs", {})
|
||||
elif isinstance(config, str):
|
||||
module = get_module_by_module_path(default_module)
|
||||
|
||||
klass = getattr(module, config)
|
||||
kwargs = {}
|
||||
else:
|
||||
@@ -212,8 +222,8 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> object:
|
||||
config: Union[str, dict, object], default_module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
get initialized instance with config
|
||||
|
||||
@@ -227,13 +237,19 @@ def init_instance_by_config(
|
||||
'module_path': path,  # It is optional if module is given
|
||||
}
|
||||
str example.
|
||||
"ClassName": getattr(module, config)() will be used.
|
||||
1) specify a pickle object
|
||||
- path like 'file:///<path to pickle file>/obj.pkl'
|
||||
2) specify a class name
|
||||
- "ClassName": getattr(module, config)() will be used.
|
||||
object example:
|
||||
instance of accept_types
|
||||
module : Python module
|
||||
default_module : Python module
|
||||
Optional. It should be a python module.
|
||||
NOTE: the "module_path" will be override by `module` arguments
|
||||
|
||||
This function will load class from the config['module_path'] first.
|
||||
If config['module_path'] doesn't exists, it will load the class from default_module.
|
||||
|
||||
accept_types: Union[type, Tuple[type]]
|
||||
Optional. If the config is a instance of specific type, return the config directly.
|
||||
This will be passed into the second parameter of isinstance.
|
||||
@@ -246,10 +262,14 @@ def init_instance_by_config(
|
||||
if isinstance(config, accept_types):
|
||||
return config
|
||||
|
||||
if module is None:
|
||||
module = get_module_by_module_path(config["module_path"])
|
||||
if isinstance(config, str):
|
||||
# path like 'file:///<path to pickle file>/obj.pkl'
|
||||
pr = urlparse(config)
|
||||
if pr.scheme == "file":
|
||||
with open(os.path.join(pr.netloc, pr.path), "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
klass, cls_kwargs = get_cls_kwargs(config, module)
|
||||
klass, cls_kwargs = get_cls_kwargs(config, default_module=default_module)
|
||||
return klass(**cls_kwargs, **kwargs)
|
||||
|
||||
|
||||
@@ -502,7 +522,7 @@ def get_date_range(trading_date, left_shift=0, right_shift=0, future=False):
|
||||
return calendar
|
||||
|
||||
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
def get_date_by_shift(trading_date, shift, future=False, clip_shift=True, freq="day"):
|
||||
"""get trading date with shift bias wil cur_date
|
||||
e.g. : shift == 1, return next trading date
|
||||
shift == -1, return previous trading date
|
||||
@@ -515,7 +535,7 @@ def get_date_by_shift(trading_date, shift, future=False, clip_shift=True):
|
||||
"""
|
||||
from qlib.data import D
|
||||
|
||||
cal = D.calendar(future=future)
|
||||
cal = D.calendar(future=future, freq=freq)
|
||||
if pd.to_datetime(trading_date) not in list(cal):
|
||||
raise ValueError("{} is not trading day!".format(str(trading_date)))
|
||||
_index = bisect.bisect_left(cal, trading_date)
|
||||
@@ -696,23 +716,33 @@ def lazy_sort_index(df: pd.DataFrame, axis=0) -> pd.DataFrame:
|
||||
return df.sort_index(axis=axis)
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep="."):
|
||||
"""flatten_dict.
|
||||
FLATTEN_TUPLE = "_FLATTEN_TUPLE"
|
||||
|
||||
|
||||
def flatten_dict(d, parent_key="", sep=".") -> dict:
|
||||
"""
|
||||
Flatten a nested dict.
|
||||
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]})
|
||||
>>> {'a': 1, 'c.a': 2, 'c.b.x': 5, 'd': [1, 2, 3], 'c.b.y': 10}
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d :
|
||||
d
|
||||
parent_key :
|
||||
parent_key
|
||||
sep :
|
||||
sep
|
||||
>>> flatten_dict({'a': 1, 'c': {'a': 2, 'b': {'x': 5, 'y' : 10}}, 'd': [1, 2, 3]}, sep=FLATTEN_TUPLE)
|
||||
>>> {'a': 1, ('c','a'): 2, ('c','b','x'): 5, 'd': [1, 2, 3], ('c','b','y'): 10}
|
||||
|
||||
Args:
|
||||
d (dict): the dict to be flattened
|
||||
parent_key (str, optional): the parent key, will be a prefix in new key. Defaults to "".
|
||||
sep (str, optional): the separator for string connecting. FLATTEN_TUPLE for tuple connecting.
|
||||
|
||||
Returns:
|
||||
dict: flatten dict
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if sep == FLATTEN_TUPLE:
|
||||
new_key = (parent_key, k) if parent_key else k
|
||||
else:
|
||||
new_key = parent_key + sep + k if parent_key else k
|
||||
if isinstance(v, collections.abc.MutableMapping):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
|
||||
@@ -3,16 +3,24 @@
|
||||
|
||||
from pathlib import Path
|
||||
import pickle
|
||||
import typing
|
||||
import dill
|
||||
from typing import Union
|
||||
|
||||
|
||||
class Serializable:
|
||||
"""
|
||||
Serializable behaves like pickle.
|
||||
But it only saves the state whose name **does not** start with `_`
|
||||
Serializable will change the behaviors of pickle.
|
||||
- It only saves the state whose name **does not** start with `_`
|
||||
It provides syntactic sugar for distinguishing the attributes which the user doesn't want.
|
||||
- For example, a learnable Datahandler just wants to save the parameters without data when dumping to disk
|
||||
"""
|
||||
|
||||
pickle_backend = "pickle" # another optional value is "dill" which can pickle more things of python.
|
||||
default_dump_all = False # if dump all things
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._dump_all = self.default_dump_all
|
||||
self._exclude = []
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
@@ -33,18 +41,86 @@ class Serializable:
|
||||
@property
|
||||
def exclude(self):
|
||||
"""
|
||||
What attribute will be dumped
|
||||
What attribute will not be dumped
|
||||
"""
|
||||
return getattr(self, "_exclude", [])
|
||||
|
||||
def config(self, dump_all: bool = None, exclude: list = None):
|
||||
if dump_all is not None:
|
||||
self._dump_all = dump_all
|
||||
FLAG_KEY = "_qlib_serial_flag"
|
||||
|
||||
if exclude is not None:
|
||||
self._exclude = exclude
|
||||
def config(self, dump_all: bool = None, exclude: list = None, recursive=False):
|
||||
"""
|
||||
configure the serializable object
|
||||
|
||||
def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None):
|
||||
Parameters
|
||||
----------
|
||||
dump_all : bool
|
||||
whether the object will dump all of its attributes
|
||||
exclude : list
|
||||
What attribute will not be dumped
|
||||
recursive : bool
|
||||
whether the configuration is also applied recursively to the Serializable attributes of this object
|
||||
"""
|
||||
|
||||
params = {"dump_all": dump_all, "exclude": exclude}
|
||||
|
||||
for k, v in params.items():
|
||||
if v is not None:
|
||||
attr_name = f"_{k}"
|
||||
setattr(self, attr_name, v)
|
||||
|
||||
if recursive:
|
||||
for obj in self.__dict__.values():
|
||||
# set flag to prevent endless loop
|
||||
self.__dict__[self.FLAG_KEY] = True
|
||||
if isinstance(obj, Serializable) and self.FLAG_KEY not in obj.__dict__:
|
||||
obj.config(**params, recursive=True)
|
||||
del self.__dict__[self.FLAG_KEY]
|
||||
|
||||
def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None):
|
||||
"""
|
||||
Dump self to a pickle file.
|
||||
|
||||
Args:
|
||||
path (Union[Path, str]): the path to dump
|
||||
dump_all (bool, optional): if need to dump all things. Defaults to None.
|
||||
exclude (list, optional): will exclude the attributes in this list when dumping. Defaults to None.
|
||||
"""
|
||||
self.config(dump_all=dump_all, exclude=exclude)
|
||||
with Path(path).open("wb") as f:
|
||||
pickle.dump(self, f)
|
||||
self.get_backend().dump(self, f)
|
||||
|
||||
@classmethod
|
||||
def load(cls, filepath):
|
||||
"""
|
||||
Load the serializable object from a filepath.
|
||||
|
||||
Args:
|
||||
filepath (str): the path of file
|
||||
|
||||
Raises:
|
||||
TypeError: the pickled object must be an instance of `cls`
|
||||
|
||||
Returns:
|
||||
Serializable: the loaded instance of `cls`
|
||||
"""
|
||||
with open(filepath, "rb") as f:
|
||||
object = cls.get_backend().load(f)
|
||||
if isinstance(object, cls):
|
||||
return object
|
||||
else:
|
||||
raise TypeError(f"The instance of {type(object)} is not a valid `{type(cls)}`!")
|
||||
|
||||
@classmethod
|
||||
def get_backend(cls):
|
||||
"""
|
||||
Return the real backend of a Serializable class. The pickle_backend value can be "pickle" or "dill".
|
||||
|
||||
Returns:
|
||||
module: pickle or dill module based on pickle_backend
|
||||
"""
|
||||
if cls.pickle_backend == "pickle":
|
||||
return pickle
|
||||
elif cls.pickle_backend == "dill":
|
||||
return dill
|
||||
else:
|
||||
raise ValueError("Unknown pickle backend, please use 'pickle' or 'dill'.")
|
||||
|
||||
@@ -23,7 +23,10 @@ class QlibRecorder:
|
||||
@contextmanager
|
||||
def start(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
name of the experiment one wants to start.
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -57,7 +64,14 @@ class QlibRecorder:
|
||||
resume : bool
|
||||
whether to resume the specific recorder with given name under the given experiment.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
run = self.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -65,7 +79,9 @@ class QlibRecorder:
|
||||
raise e
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
|
||||
def start_exp(
|
||||
self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False
|
||||
):
|
||||
"""
|
||||
Lower level method for starting an experiment. When use this method, one should end the experiment manually
|
||||
and the status of the recorder may not be handled properly. Here is the example code:
|
||||
@@ -79,8 +95,12 @@ class QlibRecorder:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the experiment one wants to start.
|
||||
experiment_name : str
|
||||
the name of the experiment to be started
|
||||
recorder_id : str
|
||||
id of the recorder under the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
@@ -93,7 +113,14 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance being started.
|
||||
"""
|
||||
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
return self.exp_manager.start_exp(
|
||||
experiment_id=experiment_id,
|
||||
experiment_name=experiment_name,
|
||||
recorder_id=recorder_id,
|
||||
recorder_name=recorder_name,
|
||||
uri=uri,
|
||||
resume=resume,
|
||||
)
|
||||
|
||||
def end_exp(self, recorder_status=Recorder.STATUS_FI):
|
||||
"""
|
||||
@@ -202,13 +229,13 @@ class QlibRecorder:
|
||||
|
||||
- no id or name specified, return the active experiment.
|
||||
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name.
|
||||
|
||||
- If `active experiment` not exists:
|
||||
|
||||
- no id or name specified, create a default experiment, and the experiment is set to be active.
|
||||
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment, and the experiment is set to be active.
|
||||
- if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given name or the default experiment.
|
||||
|
||||
- Else If '`create`' is False:
|
||||
|
||||
@@ -260,7 +287,7 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create)
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -304,7 +331,7 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
@@ -358,7 +385,7 @@ class QlibRecorder:
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
@@ -416,6 +443,12 @@ class QlibRecorder:
|
||||
"""
|
||||
self.get_exp().get_recorder().save_objects(local_path, artifact_path, **kwargs)
|
||||
|
||||
def load_object(self, name: Text):
|
||||
"""
|
||||
Method for loading an object from artifacts in the experiment in the uri.
|
||||
"""
|
||||
return self.get_exp().get_recorder().load_object(name)
|
||||
|
||||
def log_params(self, **kwargs):
|
||||
"""
|
||||
Method for logging parameters during an experiment. In addition to using ``R``, one can also log to a specific recorder after getting it with `get_recorder` API.
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
from pathlib import Path
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Experiment:
|
||||
@@ -39,12 +39,14 @@ class Experiment:
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
"""
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder_id : str
|
||||
the id of the recorder to be created.
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
resume : bool
|
||||
@@ -107,24 +109,24 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_recorder` method.")
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True):
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve a Recorder for user. When user specify recorder id and name, the method will try to return the
|
||||
specific recorder. When user does not provide recorder id or name, the method will try to return the current
|
||||
active recorder. The `create` argument determines whether the method will automatically create a new recorder
|
||||
according to user's specification if the recorder hasn't been created before
|
||||
according to user's specification if the recorder hasn't been created before.
|
||||
|
||||
* If `create` is True:
|
||||
|
||||
* If `active recorder` exists:
|
||||
|
||||
* no id or name specified, return the active recorder.
|
||||
* if id or name is specified, return the specified recorder. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
|
||||
* if id or name is specified, return the specified recorder. If no such recorder is found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
|
||||
|
||||
* If `active recorder` not exists:
|
||||
|
||||
* no id or name specified, create a new recorder.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new recorder with given id or name, and the recorder shoud be active.
|
||||
* if id or name is specified, return the specified recorder. If no such recorder is found, create a new recorder with given id or name. If `start` is set to be True, the recorder is set to be active.
|
||||
|
||||
* Else If `create` is False:
|
||||
|
||||
@@ -146,6 +148,8 @@ class Experiment:
|
||||
the name of the recorder to be deleted.
|
||||
create : boolean
|
||||
create the recorder if it hasn't been created before.
|
||||
start : boolean
|
||||
start the new recorder if one is created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -159,8 +163,11 @@ class Experiment:
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
recorder, is_new = (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
if is_new and start:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -174,7 +181,10 @@ class Experiment:
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
return (
|
||||
self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
@@ -230,14 +240,14 @@ class MLflowExperiment(Experiment):
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
def start(self, *, recorder_id=None, recorder_name=None, resume=False):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# Get or create recorder
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
# resume the recorder
|
||||
if resume:
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
# create a new recorder
|
||||
else:
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import mlflow
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
import os, logging
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
@@ -14,7 +14,7 @@ from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class ExpManager:
|
||||
@@ -33,7 +33,10 @@ class ExpManager:
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -45,8 +48,12 @@ class ExpManager:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id : str
|
||||
id of the active experiment.
|
||||
experiment_name : str
|
||||
name of the active experiment.
|
||||
recorder_id : str
|
||||
id of the recorder to be started.
|
||||
recorder_name : str
|
||||
name of the recorder to be started.
|
||||
uri : str
|
||||
@@ -102,10 +109,9 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True):
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
The returned experiment will be active.
|
||||
|
||||
When user specify experiment id and name, the method will try to return the specific experiment.
|
||||
When user does not provide experiment id or name, the method will try to return the current active experiment.
|
||||
@@ -117,12 +123,12 @@ class ExpManager:
|
||||
* If `active experiment` exists:
|
||||
|
||||
* no id or name specified, return the active experiment.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
|
||||
|
||||
* If `active experiment` not exists:
|
||||
|
||||
* no id or name specified, create a default experiment.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name, and the experiment is set to be active.
|
||||
* if id or name is specified, return the specified experiment. If no such exp found, create a new experiment with given id or name. If `start` is set to be True, the experiment is set to be active.
|
||||
|
||||
* Else If `create` is False:
|
||||
|
||||
@@ -144,6 +150,8 @@ class ExpManager:
|
||||
name of the experiment to return.
|
||||
create : boolean
|
||||
create the experiment if it hasn't been created before.
|
||||
start : boolean
|
||||
start the new experiment if one is created.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -159,8 +167,11 @@ class ExpManager:
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
else:
|
||||
exp, is_new = self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
if is_new:
|
||||
exp, is_new = (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
if is_new and start:
|
||||
self.active_experiment = exp
|
||||
# start the recorder
|
||||
self.active_experiment.start()
|
||||
@@ -172,7 +183,10 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
return (
|
||||
self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name),
|
||||
False,
|
||||
)
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
@@ -291,7 +305,10 @@ class MLflowExpManager(ExpManager):
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
*,
|
||||
experiment_id: Optional[Text] = None,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_id: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
@@ -301,11 +318,11 @@ class MLflowExpManager(ExpManager):
|
||||
# Create experiment
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name, resume)
|
||||
self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
|
||||
0
qlib/workflow/online/__init__.py
Normal file
0
qlib/workflow/online/__init__.py
Normal file
304
qlib/workflow/online/manager.py
Normal file
304
qlib/workflow/online/manager.py
Normal file
@@ -0,0 +1,304 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically.
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
|
||||
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
|
||||
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
|
||||
There are 4 total situations for using different trainers in different situations:
|
||||
|
||||
|
||||
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
every routine and prepare signals for every routine.
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Callable, Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data.data import D
|
||||
from qlib.log import set_global_logger_level
|
||||
from qlib.model.ens.ensemble import AverageEnsemble
|
||||
from qlib.model.trainer import DelayTrainerR, Trainer, TrainerR
|
||||
from qlib.utils import flatten_dict
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow.online.strategy import OnlineStrategy
|
||||
from qlib.workflow.task.collect import MergeCollector
|
||||
|
||||
|
||||
class OnlineManager(Serializable):
|
||||
"""
|
||||
OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_.
|
||||
It also provides a history recording of which models are online at what time.
|
||||
"""
|
||||
|
||||
STATUS_SIMULATING = "simulating" # when calling `simulate`
|
||||
STATUS_NORMAL = "normal" # the normal status
|
||||
|
||||
def __init__(
    self,
    strategies: Union[OnlineStrategy, List[OnlineStrategy]],
    trainer: Trainer = None,
    begin_time: Union[str, pd.Timestamp] = None,
    freq="day",
):
    """
    Build an OnlineManager around one or more strategies.

    Args:
        strategies (Union[OnlineStrategy, List[OnlineStrategy]]): a single OnlineStrategy or a list of them.
            A single instance is wrapped into a one-element list.
        trainer (Trainer): the trainer used to train tasks. None means a fresh TrainerR is created.
        begin_time (Union[str,pd.Timestamp], optional): the time the manager starts at.
            None means the maximum of the calendar for `freq`.
        freq (str, optional): data frequency. Defaults to "day".
    """
    self.logger = get_module_logger(self.__class__.__name__)
    self.strategies = strategies if isinstance(strategies, list) else [strategies]
    self.freq = freq
    # Without an explicit start, fall back to the last entry of the calendar.
    start = D.calendar(freq=self.freq).max() if begin_time is None else begin_time
    self.begin_time = pd.Timestamp(start)
    self.cur_time = self.begin_time
    # History of online models, shaped like {pd.Timestamp: {strategy: [online_models]}}.
    self.history = {}
    self.trainer = TrainerR() if trainer is None else trainer
    self.signals = None
    self.status = self.STATUS_NORMAL
|
||||
|
||||
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = None):
    """
    Get tasks from every strategy's first_tasks method and train them.
    If using DelayTrainer, it can finish training all together after every strategy's first_tasks.

    Args:
        strategies (List[OnlineStrategy]): the strategies to train (needed when adding strategies).
            None means `self.strategies`.
        model_kwargs (dict): the params for `prepare_online_models`. None means no extra params.
    """
    if strategies is None:
        strategies = self.strategies
    # A `None` default avoids sharing one mutable dict object between calls.
    if model_kwargs is None:
        model_kwargs = {}

    for strategy in strategies:
        self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
        tasks = strategy.first_tasks()
        models = self.trainer.train(tasks, experiment_name=strategy.name_id)
        models = self.trainer.end_train(models, experiment_name=strategy.name_id)
        self.logger.info(f"Finished training {len(models)} models.")

        online_models = strategy.prepare_online_models(models, **model_kwargs)
        # Record which models each strategy put online at the current time.
        self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
def routine(
    self,
    cur_time: Union[str, pd.Timestamp] = None,
    task_kwargs: dict = None,
    model_kwargs: dict = None,
    signal_kwargs: dict = None,
):
    """
    Typical update process for every strategy and record the online history.

    The typical update process after a routine, such as day by day or month by month.
    The process is: Update predictions -> Prepare tasks -> Prepare online models -> Prepare signals.

    If using DelayTrainer, it can finish training all together after every strategy's prepare_tasks.

    Args:
        cur_time (Union[str,pd.Timestamp], optional): run the routine at this time.
            None means the maximum of the calendar for `self.freq`.
        task_kwargs (dict): the params for `prepare_tasks`. None means no extra params.
        model_kwargs (dict): the params for `prepare_online_models`. None means no extra params.
        signal_kwargs (dict): the params for `prepare_signals`. None means no extra params.
    """
    # `None` defaults avoid sharing mutable dict objects between calls.
    task_kwargs = {} if task_kwargs is None else task_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs

    if cur_time is None:
        cur_time = D.calendar(freq=self.freq).max()
    self.cur_time = pd.Timestamp(cur_time)

    for strategy in self.strategies:
        self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
        if self.status == self.STATUS_NORMAL:
            strategy.tool.update_online_pred()

        tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
        # Train under the strategy's experiment name, matching `first_train` and `end_train`.
        models = self.trainer.train(tasks, experiment_name=strategy.name_id)
        if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
            models = self.trainer.end_train(models, experiment_name=strategy.name_id)
        self.logger.info(f"Finished training {len(models)} models.")
        online_models = strategy.prepare_online_models(models, **model_kwargs)
        self.history.setdefault(self.cur_time, {})[strategy] = online_models

    # With a delay trainer the signals are prepared later (see `delay_prepare`).
    if not self.trainer.is_delay():
        self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
    """
    Build a `Collector <../advanced/task_management.html#Task Collecting>`_ that merges the collectors of all strategies.

    The merged collector can serve as the basis for preparing signals.

    Returns:
        MergeCollector: a collector keyed by each strategy's `name_id`, created with an empty process list.
    """
    per_strategy = {strategy.name_id: strategy.get_collector() for strategy in self.strategies}
    return MergeCollector(per_strategy, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
    """
    Register new strategies: run their first training, then append them to the managed list.

    Args:
        strategies (Union[OnlineStrategy, List[OnlineStrategy]]): one OnlineStrategy or a list of them.
    """
    incoming = strategies if isinstance(strategies, list) else [strategies]
    # Train first so a failing first_train leaves `self.strategies` untouched.
    self.first_train(incoming)
    self.strategies.extend(incoming)
|
||||
|
||||
def prepare_signals(self, prepare_func: Callable = AverageEnsemble(), over_write=False):
    """
    Collect the results of every strategy and turn them into trading signals.

    Call this once the data of the last routine is ready, so that the signals for the next routine can be prepared.

    .. note::

        Given a prediction of a certain time, all signals before this time will be prepared well.

    Args:
        prepare_func (Callable, optional): maps the collected dict to signals. Defaults to AverageEnsemble();
            the results collected by MergeCollector must then be {xxx:pred}.
        over_write (bool, optional): If True, the freshly computed signals replace the stored ones.
            If False, the new signals starting at the latest stored datetime are appended to the stored ones.
            Defaults to False.

    Returns:
        pd.DataFrame: the new signals (the appended part, or everything when nothing was stored or `over_write` is True).
    """
    collected = self.get_collector()()
    signals = prepare_func(collected)
    previous = self.signals
    if previous is None or over_write:
        new_signals = signals
    else:
        latest_stored = previous.index.get_level_values("datetime").max()
        # NOTE(review): `.loc` label slicing is inclusive, so rows stamped exactly `latest_stored`
        # may end up in both parts of the concat -- confirm whether de-duplication is expected.
        new_signals = signals.loc[latest_stored:]
        signals = pd.concat([previous, new_signals], axis=0)
    self.logger.info(f"Finished preparing new {len(new_signals)} signals.")
    self.signals = signals
    return new_signals
|
||||
|
||||
def get_signals(self) -> Union[pd.Series, pd.DataFrame]:
    """
    Return the signals stored by the last signal preparation.

    Returns:
        Union[pd.Series, pd.DataFrame]: pd.Series when there is a single signal per datetime;
        pd.DataFrame for multiple signals, for example when buy and sell operations use different trading signals.
    """
    prepared = self.signals
    return prepared
|
||||
|
||||
SIM_LOG_LEVEL = logging.INFO + 1 # when simulating, reduce information
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(
    self, end_time, frequency="day", task_kwargs=None, model_kwargs=None, signal_kwargs=None
) -> Union[pd.Series, pd.DataFrame]:
    """
    Starting from the current time, this method will simulate every routine in OnlineManager until the end time.

    Considering the parallel training, the models and signals can be prepared after all routine simulating.

    The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``.

    The manager status is always reset to normal when this method exits, and the global logger level
    is reset once it has been changed, even if a routine raises.

    Args:
        end_time: the time the simulation will end
        frequency: the calendar frequency
        task_kwargs (dict): the params for `prepare_tasks`. None means no extra params.
        model_kwargs (dict): the params for `prepare_online_models`. None means no extra params.
        signal_kwargs (dict): the params for `prepare_signals`. None means no extra params.

    Returns:
        Union[pd.Series, pd.DataFrame]: pd.Series for only one signals every datetime.
        pd.DataFrame for multiple signals, for example, buy and sell operations use different trading signals.
    """
    # `None` defaults avoid sharing mutable dict objects between calls.
    task_kwargs = {} if task_kwargs is None else task_kwargs
    model_kwargs = {} if model_kwargs is None else model_kwargs
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs

    self.status = self.STATUS_SIMULATING
    try:
        cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency)
        self.first_train()

        # when simulating, reduce information
        simulate_level = self.SIM_LOG_LEVEL
        set_global_logger_level(simulate_level)
        try:
            logging.addLevelName(simulate_level, self.SIM_LOG_NAME)

            for cur_time in cal:
                self.logger.log(level=simulate_level, msg=f"Simulating at {str(cur_time)}......")
                self.routine(
                    cur_time,
                    task_kwargs=task_kwargs,
                    model_kwargs=model_kwargs,
                    signal_kwargs=signal_kwargs,
                )
            # delay prepare the models and signals
            if self.trainer.is_delay():
                self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
        finally:
            # FIXME: get logging level firstly and restore it here
            set_global_logger_level(logging.DEBUG)
        self.logger.info("Finished preparing signals")
    finally:
        # Never leave the manager stuck in the simulating state after a failure.
        self.status = self.STATUS_NORMAL
    return self.get_signals()
|
||||
|
||||
def delay_prepare(self, model_kwargs=None, signal_kwargs=None):
    """
    Prepare all models and signals if something is waiting for preparation.

    Walks through the recorded history in insertion order; whenever the set of online models of a
    strategy differs from the previous entry, the models are finished via `end_train`, tagged online,
    and the signals are prepared again.

    Args:
        model_kwargs: the params for `end_train`. None means no extra params.
        signal_kwargs: the params for `prepare_signals`. None means no extra params.
    """
    # `None` defaults avoid sharing mutable dict objects between calls.
    model_kwargs = {} if model_kwargs is None else model_kwargs
    signal_kwargs = {} if signal_kwargs is None else signal_kwargs

    last_models = {}
    signals_time = D.calendar()[0]
    need_prepare = False
    for cur_time, strategy_models in self.history.items():
        self.cur_time = cur_time

        for strategy, models in strategy_models.items():
            # only new online models need to prepare
            if last_models.setdefault(strategy, set()) != set(models):
                models = self.trainer.end_train(models, experiment_name=strategy.name_id, **model_kwargs)
                strategy.tool.reset_online_tag(models)
                need_prepare = True
            last_models[strategy] = set(models)

        if need_prepare:
            # NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
            self.prepare_signals(**signal_kwargs)
            if signals_time > cur_time:
                # `Logger.warn` is a deprecated alias (removed in Python 3.13); use `warning`.
                self.logger.warning(
                    f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
                )
            need_prepare = False
            signals_time = self.signals.index.get_level_values("datetime").max()
|
||||
211
qlib/workflow/online/strategy.py
Normal file
211
qlib/workflow/online/strategy.py
Normal file
@@ -0,0 +1,211 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineStrategy module is an element of online serving.
|
||||
"""
|
||||
|
||||
from copy import deepcopy
|
||||
from typing import List, Tuple, Union
|
||||
from qlib.data.data import D
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.workflow.online.utils import OnlineTool, OnlineToolR
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.collect import Collector, RecorderCollector
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
|
||||
|
||||
class OnlineStrategy:
    """
    OnlineStrategy works together with `Online Manager <#Online Manager>`_ and is responsible for how tasks are generated, how models are updated and how signals are prepared.
    """

    def __init__(self, name_id: str):
        """
        Init OnlineStrategy.
        This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finish model training.

        Args:
            name_id (str): a unique name or id.
        """
        self.name_id = name_id
        self.logger = get_module_logger(self.__class__.__name__)
        self.tool = OnlineTool()

    def prepare_tasks(self, cur_time, **kwargs) -> List[dict]:
        """
        Decide, after a routine has ended, whether new tasks must be prepared and trained for `cur_time` (None for latest).
        Return the new tasks waiting for training.

        You can find the last online models by OnlineTool.online_models.
        """
        raise NotImplementedError("Please implement the `prepare_tasks` method.")

    def prepare_online_models(self, trained_models, cur_time=None) -> List[object]:
        """
        Pick models from the trained models and make them the online models.
        This default implementation puts all trained models online; override it for anything more complex.
        You can find the last online models by OnlineTool.online_models if you still need them.

        NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.

        Args:
            trained_models (list): a list of trained models.
            cur_time: current time from OnlineManager. None for the latest. Not used by this default implementation.

        Returns:
            List[object]: a list of online models.
        """
        if trained_models:
            self.tool.reset_online_tag(trained_models)
            return trained_models
        # Nothing was trained: keep whatever is online right now.
        return self.tool.online_models()

    def first_tasks(self) -> List[dict]:
        """
        Generate the initial series of tasks and return them.
        """
        raise NotImplementedError("Please implement the `first_tasks` method.")

    def get_collector(self) -> Collector:
        """
        Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect different results of this strategy.

        For example:
            1) collect predictions in Recorder
            2) collect signals in a txt file

        Returns:
            Collector
        """
        raise NotImplementedError("Please implement the `get_collector` method.")
|
||||
|
||||
|
||||
class RollingStrategy(OnlineStrategy):
|
||||
|
||||
"""
|
||||
This example strategy always uses the latest rolling models as online models.
|
||||
"""
|
||||
|
||||
def __init__(
    self,
    name_id: str,
    task_template: Union[dict, List[dict]],
    rolling_gen: RollingGen,
):
    """
    Init RollingStrategy.

    Assumption: the str of name_id, the experiment name, and the trainer's experiment name are the same.

    Args:
        name_id (str): a unique name or id. Also used as the experiment name.
        task_template (Union[dict, List[dict]]): a single task template or a list of them; a single
            template is wrapped into a list. Used with rolling_gen to generate tasks.
        rolling_gen (RollingGen): an instance of RollingGen
    """
    super().__init__(name_id=name_id)
    self.exp_name = self.name_id
    self.task_template = task_template if isinstance(task_template, list) else [task_template]
    self.rg = rolling_gen
    # Replaces the tool created by the parent initializer with one bound to this experiment name.
    self.tool = OnlineToolR(self.exp_name)
    self.ta = TimeAdjuster()
|
||||
|
||||
def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None):
    """
    Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must distinguish results in different models.

    Assumption: the models can be distinguished based on the model name and rolling test segments.
    If you do not want this assumption, please implement your method or use another rec_key_func.

    Args:
        process_list (list, optional): passed through to RecorderCollector. The default list (holding one
            RollingGroup) is created once at definition time and shared by all calls that rely on it.
        rec_key_func (Callable): a function to get the key of a recorder. If None, the key is the pair
            (model class, test segment) read from the recorder's "task" object.
        rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
        artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts.
    """

    def rec_key(recorder):
        task = recorder.load_object("task")
        return task["model"]["class"], task["dataset"]["kwargs"]["segments"]["test"]

    return RecorderCollector(
        experiment=self.exp_name,
        process_list=process_list,
        rec_key_func=rec_key if rec_key_func is None else rec_key_func,
        rec_filter_func=rec_filter_func,
        artifacts_key=artifacts_key,
    )
|
||||
|
||||
def first_tasks(self) -> List[dict]:
    """
    Expand the task templates with the rolling generator into the initial list of tasks.

    Returns:
        List[dict]: a list of tasks
    """
    # The rolling generator produces the different date segments.
    generated = task_generator(tasks=self.task_template, generators=self.rg)
    return generated
|
||||
|
||||
def prepare_tasks(self, cur_time) -> List[dict]:
    """
    Prepare new tasks based on cur_time (None for the latest).

    New tasks are generated only when the interval between `cur_time` and the begin of the latest
    online test segment has reached the rolling step.

    You can find the last online models by OnlineToolR.online_models.

    Args:
        cur_time: the current time. None means the last entry of the calendar.

    Returns:
        List[dict]: a list of new tasks (empty when there are no online recorders or no rolling is due).
    """
    latest_records, max_test = self._list_latest(self.tool.online_models())
    if max_test is None:
        # `Logger.warn` is a deprecated alias (removed in Python 3.13); use `warning`.
        self.logger.warning("No latest online recorders, no new tasks.")
        return []
    calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time
    # Compute the interval once and reuse it for both the log line and the decision.
    interval = self.ta.cal_interval(calendar_latest, max_test[0])
    self.logger.info(
        f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {interval}, the rolling step is {self.rg.step}"
    )
    if interval < self.rg.step:
        return []

    old_tasks = []
    tasks_tmp = []
    for rec in latest_records:
        task = rec.load_object("task")
        # Keep an untouched copy so already-known tasks can be filtered out below.
        old_tasks.append(deepcopy(task))
        test_begin = task["dataset"]["kwargs"]["segments"]["test"][0]
        # modify the test segment to generate new tasks
        task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest)
        tasks_tmp.append(task)
    new_tasks_tmp = task_generator(tasks_tmp, self.rg)
    return [task for task in new_tasks_tmp if task not in old_tasks]
|
||||
|
||||
def _list_latest(self, rec_list: List[Recorder]):
    """
    List the latest recorders from rec_list.

    "Latest" means the recorder's test segment (read from its "task" object) equals the maximum
    test segment among all recorders.

    Args:
        rec_list (List[Recorder]): a list of Recorder

    Returns:
        List[Recorder], tuple: the latest recorders and the maximum test segment.
        For an empty `rec_list`, returns `rec_list` itself and None.
    """
    if len(rec_list) == 0:
        return rec_list, None
    # Load every recorder's task only once; `load_object` may be an expensive artifact read.
    test_segments = [rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list]
    max_test = max(test_segments)
    latest_rec = [rec for rec, segment in zip(rec_list, test_segments) if segment == max_test]
    return latest_rec, max_test
|
||||
160
qlib/workflow/online/update.py
Normal file
160
qlib/workflow/online/update.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Updater is a module to update artifacts such as predictions when the stock data is updating.
|
||||
"""
|
||||
|
||||
from abc import ABCMeta, abstractmethod
|
||||
|
||||
import pandas as pd
|
||||
from qlib import get_module_logger
|
||||
from qlib.data import D
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.model import Model
|
||||
from qlib.utils import get_date_by_shift
|
||||
from qlib.workflow.recorder import Recorder
|
||||
|
||||
|
||||
class RMDLoader:
    """
    Recorder Model Dataset Loader: loads the model and the dataset stored in a recorder.
    """

    def __init__(self, rec: Recorder):
        self.rec = rec

    def get_dataset(self, start_time, end_time, segments=None) -> DatasetH:
        """
        Load the "dataset" object from the recorder, reconfigure it and set up its data.

        This dataset is for inference.

        Args:
            start_time :
                the start_time of underlying data
            end_time :
                the end_time of underlying data
            segments : dict
                the segments config for dataset. None means a single "test" segment (start_time, end_time).
                Due to the time series dataset (TSDatasetH), the test segments may be different from start_time and end_time

        Returns:
            DatasetH: the instance of DatasetH
        """
        chosen_segments = {"test": (start_time, end_time)} if segments is None else segments
        dataset: DatasetH = self.rec.load_object("dataset")
        time_range = {"start_time": start_time, "end_time": end_time}
        dataset.config(handler_kwargs=time_range, segments=chosen_segments)
        dataset.setup_data(handler_kwargs={"init_type": DataHandlerLP.IT_LS})
        return dataset

    def get_model(self) -> Model:
        """Load the object stored as "params.pkl" in the recorder and return it."""
        model = self.rec.load_object("params.pkl")
        return model
|
||||
|
||||
|
||||
class RecordUpdater(metaclass=ABCMeta):
    """
    Abstract base for updaters that refresh the information stored in one specific recorder.
    """

    def __init__(self, record: Recorder, *args, **kwargs):
        # Extra positional/keyword arguments are accepted but not used here.
        self.record = record
        self.logger = get_module_logger(self.__class__.__name__)

    @abstractmethod
    def update(self, *args, **kwargs):
        """
        Update the information of the recorder; concrete updaters must implement this.
        """
|
||||
|
||||
|
||||
class PredUpdater(RecordUpdater):
|
||||
"""
|
||||
Update the prediction in the Recorder
|
||||
"""
|
||||
|
||||
def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day"):
    """
    Init PredUpdater.

    Loads the existing prediction ("pred.pkl") from the recorder and remembers its latest datetime.

    Args:
        record : Recorder
        to_date :
            update the prediction up to `to_date`. None means the last entry of the calendar for `freq`.
        hist_ref : int
            Sometimes, the dataset will have historical depends.
            Leave the problem to users to set the length of historical dependency

            .. note::

                the start_time is not included in the hist_ref
        freq : str
            the calendar frequency used for date shifting.

    """
    # TODO: automate this hist_ref in the future.
    super().__init__(record=record)

    self.hist_ref = hist_ref
    self.freq = freq
    self.rmdl = RMDLoader(rec=record)

    # Identity check: `== None` would invoke the argument's own `__eq__`.
    if to_date is None:
        to_date = D.calendar(freq=freq)[-1]
    self.to_date = pd.Timestamp(to_date)
    self.old_pred = record.load_object("pred.pkl")
    self.last_end = self.old_pred.index.get_level_values("datetime").max()
|
||||
|
||||
def prepare_data(self) -> DatasetH:
    """
    Load the dataset covering the dates that still need predictions.

    Kept as a separate method so the dataset can be prepared once and reused.

    Returns:
        DatasetH: the instance of DatasetH
    """
    # The underlying data starts earlier than the test segment to cover the historical dependency.
    buffer_start = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq)
    test_start = get_date_by_shift(self.last_end, 1, freq=self.freq)
    return self.rmdl.get_dataset(
        start_time=buffer_start,
        end_time=self.to_date,
        segments={"test": (test_start, self.to_date)},
    )
|
||||
|
||||
def update(self, dataset: DatasetH = None):
    """
    Update the prediction in a recorder.

    New predictions are appended to the old ones, sorted by index and saved back as "pred.pkl".
    Nothing is done when the old prediction already reaches `to_date`.

    Args:
        dataset (DatasetH): the instance of DatasetH. None means it is prepared via `prepare_data`.
    """
    # FIXME: the problem below is not solved
    # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised
    # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
    # https://github.com/pytorch/pytorch/issues/16797

    # Compare the last predicted date itself: comparing the *next* date with `>=` skipped the
    # update when exactly one new date (`to_date`) was missing.
    if self.last_end >= self.to_date:
        self.logger.info(
            f"The prediction in {self.record.info['id']} are latest ({self.last_end}). No need to update to {self.to_date}."
        )
        return

    # load dataset
    if dataset is None:
        # For reusing the dataset
        dataset = self.prepare_data()

    # Load model
    model = self.rmdl.get_model()

    new_pred: pd.Series = model.predict(dataset)

    cb_pred = pd.concat([self.old_pred, new_pred.to_frame("score")], axis=0)
    cb_pred = cb_pred.sort_index()

    self.record.save_objects(**{"pred.pkl": cb_pred})

    self.logger.info(f"Finish updating new {new_pred.shape[0]} predictions in {self.record.info['id']}.")
|
||||
168
qlib/workflow/online/utils.py
Normal file
168
qlib/workflow/online/utils.py
Normal file
@@ -0,0 +1,168 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
OnlineTool is a module to set and unset a series of `online` models.
|
||||
The `online` models are some decisive models in some time points, which can be changed with the change of time.
|
||||
This allows us to use efficient submodels as the market-style changing.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
|
||||
|
||||
class OnlineTool:
    """
    OnlineTool manages the `online` models of an experiment that holds the model recorders.

    This base class only defines the interface; every operation raises NotImplementedError.
    """

    ONLINE_KEY = "online_status"  # the online status key in recorder
    ONLINE_TAG = "online"  # the 'online' model
    OFFLINE_TAG = "offline"  # the 'offline' model, not for online serving

    def __init__(self):
        """
        Init OnlineTool.
        """
        self.logger = get_module_logger(self.__class__.__name__)

    def set_online_tag(self, tag, recorder: Union[list, object]):
        """
        Mark the given model recorder(s) with `tag` to record whether they are online.

        Args:
            tag (str): one of `ONLINE_TAG`, `OFFLINE_TAG`
            recorder (Union[list,object]): the model's recorder, or a list of them
        """
        raise NotImplementedError("Please implement the `set_online_tag` method.")

    def get_online_tag(self, recorder: object) -> str:
        """
        Return the online tag of a model recorder.

        Args:
            recorder (Object): the model's recorder

        Returns:
            str: the online tag
        """
        raise NotImplementedError("Please implement the `get_online_tag` method.")

    def reset_online_tag(self, recorder: Union[list, object]):
        """
        Offline all models, then set the given recorders to 'online'.

        Args:
            recorder (Union[list,object]):
                the recorder(s) you want to reset to 'online'.
        """
        raise NotImplementedError("Please implement the `reset_online_tag` method.")

    def online_models(self) -> list:
        """
        Get the current `online` models.

        Returns:
            list: a list of `online` models.
        """
        raise NotImplementedError("Please implement the `online_models` method.")

    def update_online_pred(self, to_date=None):
        """
        Update the predictions of the `online` models up to to_date.

        Args:
            to_date (pd.Timestamp): the pred before this date will be updated. None for updating to the latest.
        """
        raise NotImplementedError("Please implement the `update_online_pred` method.")
|
||||
|
||||
|
||||
class OnlineToolR(OnlineTool):
    """
    The implementation of OnlineTool based on (R)ecorder.

    The online status of a model is kept as a tag on its recorder, under ``ONLINE_KEY``.
    """

    def __init__(self, experiment_name: str):
        """
        Init OnlineToolR.

        Args:
            experiment_name (str): the name of the experiment whose recorders are managed.
        """
        super().__init__()
        self.exp_name = experiment_name

    def set_online_tag(self, tag, recorder: Union[Recorder, List]):
        """
        Write `tag` under ``ONLINE_KEY`` on every given recorder.

        Args:
            tag (str): one of `ONLINE_TAG`, `OFFLINE_TAG`
            recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder
        """
        targets = [recorder] if isinstance(recorder, Recorder) else recorder
        for target in targets:
            target.set_tags(**{self.ONLINE_KEY: tag})
        self.logger.info(f"Set {len(targets)} models to '{tag}'.")

    def get_online_tag(self, recorder: Recorder) -> str:
        """
        Return the online tag of a model recorder.

        Args:
            recorder (Recorder): an instance of recorder

        Returns:
            str: the value stored under ``ONLINE_KEY``, or ``OFFLINE_TAG`` when the recorder has no such tag.
        """
        return recorder.list_tags().get(self.ONLINE_KEY, self.OFFLINE_TAG)

    def reset_online_tag(self, recorder: Union[Recorder, List]):
        """
        Offline all models of the experiment, then set the given recorder(s) to 'online'.

        Args:
            recorder (Union[Recorder, List]): the recorder(s) to reset to 'online'.
        """
        to_online = [recorder] if isinstance(recorder, Recorder) else recorder
        every_recorder = list(list_recorders(self.exp_name).values())
        # Order matters: everything goes offline first, then the chosen recorders are switched back on.
        self.set_online_tag(self.OFFLINE_TAG, every_recorder)
        self.set_online_tag(self.ONLINE_TAG, to_online)

    def online_models(self) -> list:
        """
        Get the current `online` models.

        Returns:
            list: the recorders of the experiment whose online tag equals ``ONLINE_TAG``.
        """

        def _is_online(rec):
            return self.get_online_tag(rec) == self.ONLINE_TAG

        return list(list_recorders(self.exp_name, _is_online).values())

    def update_online_pred(self, to_date=None):
        """
        Update the predictions of the online models up to `to_date`.

        Args:
            to_date (pd.Timestamp): predictions before this date will be updated. None for updating to the latest time in Calendar.
        """
        online_models = self.online_models()
        for rec in online_models:
            dataset_conf = rec.load_object("task")["dataset"]
            # Special treatment of historical dependencies: a TSDatasetH task passes its
            # `step_len` to the updater as `hist_ref`; every other dataset class uses 0.
            if dataset_conf["class"] == "TSDatasetH":
                hist_ref = dataset_conf["kwargs"]["step_len"]
            else:
                hist_ref = 0
            PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()

        self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import re, logging
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
@@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
from ..contrib.strategy.strategy import BaseStrategy
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
@@ -39,7 +39,13 @@ class RecordTemp:
|
||||
return "/".join(names)
|
||||
|
||||
def __init__(self, recorder):
|
||||
self.recorder = recorder
|
||||
self._recorder = recorder
|
||||
|
||||
@property
def recorder(self):
    """
    Return the recorder bound to this record template.

    Raises:
        ValueError: if the stored recorder is None.
    """
    bound = self._recorder
    if bound is not None:
        return bound
    raise ValueError("This RecordTemp did not set recorder yet.")
|
||||
|
||||
def generate(self, **kwargs):
|
||||
"""
|
||||
@@ -145,6 +151,10 @@ class SignalRecord(RecordTemp):
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = self.dataset.prepare(**params)
|
||||
except AttributeError:
|
||||
# The data handler is initialize with `drop_raw=True`...
|
||||
# So raw_label is not available
|
||||
raw_label = None
|
||||
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
self.dataset.__class__ = orig_cls
|
||||
@@ -156,6 +166,60 @@ class SignalRecord(RecordTemp):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
    """
    Signal analysis record that evaluates the saved predictions against the saved raw labels.

    ``generate`` logs IC / Rank IC statistics, long and short precision and long-short return
    statistics as metrics, and saves the underlying results as artifacts.
    """

    # NOTE(review): "hg" may be a typo for "hf"; left unchanged because it is the artifact storage path.
    artifact_path = "hg_sig_analysis"

    def __init__(self, recorder, **kwargs):
        """
        Args:
            recorder: the recorder used to load and save artifacts.
            **kwargs: accepted but not used by this class.
        """
        super().__init__(recorder=recorder)

    def generate(self):
        """
        Compute the signal analysis from ``pred.pkl`` and ``label.pkl``, then log and save the results.

        Only the first column of the prediction and of the label is used.
        """
        pred = self.load("pred.pkl")
        raw_label = self.load("label.pkl")
        score = pred.iloc[:, 0]
        label = raw_label.iloc[:, 0]

        long_pre, short_pre = calc_long_short_prec(score, label, is_alpha=True)
        ic, ric = calc_ic(score, label)
        long_short_r, long_avg_r = calc_long_short_return(score, label)

        metrics = {
            "IC": ic.mean(),
            "ICIR": ic.mean() / ic.std(),
            "Rank IC": ric.mean(),
            "Rank ICIR": ric.mean() / ric.std(),
            "Long precision": long_pre.mean(),
            "Short precision": short_pre.mean(),
            "Long-Short Average Return": long_short_r.mean(),
            "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
        }
        objects = {
            "ic.pkl": ic,
            "ric.pkl": ric,
            "long_pre.pkl": long_pre,
            "short_pre.pkl": short_pre,
            "long_short_r.pkl": long_short_r,
            "long_avg_r.pkl": long_avg_r,
        }

        self.recorder.log_metrics(**metrics)
        self.recorder.save_objects(**objects, artifact_path=self.get_path())
        pprint(metrics)

    def list(self):
        """
        Return the paths of the artifacts saved by ``generate``.
        """
        artifact_names = (
            "ic.pkl",
            "ric.pkl",
            "long_pre.pkl",
            "short_pre.pkl",
            "long_short_r.pkl",
            "long_avg_r.pkl",
        )
        return [self.get_path(name) for name in artifact_names]
|
||||
|
||||
|
||||
class SigAnaRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
@@ -176,6 +240,9 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
if label is None or not isinstance(label, pd.DataFrame) or label.empty:
|
||||
logger.warn(f"Empty label.")
|
||||
return
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
@@ -248,11 +315,20 @@ class PortAnaRecord(SignalRecord):
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
self.recorder.save_objects(**{"report_normal.pkl": report_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(**{"positions_normal.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"report_normal.pkl": report_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{"positions_normal.pkl": positions_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
order_normal = report_dict.get("order_list")
|
||||
if order_normal:
|
||||
self.recorder.save_objects(**{"order_normal.pkl": order_normal}, artifact_path=PortAnaRecord.get_path())
|
||||
self.recorder.save_objects(
|
||||
**{"order_normal.pkl": order_normal},
|
||||
artifact_path=PortAnaRecord.get_path(),
|
||||
)
|
||||
|
||||
# analysis
|
||||
analysis = dict()
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Recorder:
|
||||
@@ -39,6 +39,9 @@ class Recorder:
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
|
||||
def __hash__(self) -> int:
    """Hash the recorder by the ``"id"`` entry of its ``info``."""
    recorder_id = self.info["id"]
    return hash(recorder_id)
|
||||
|
||||
@property
|
||||
def info(self):
|
||||
output = dict()
|
||||
@@ -232,6 +235,14 @@ class MLflowRecorder(Recorder):
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
    """Hash the recorder by the ``"id"`` entry of its ``info``, the same field ``__eq__`` compares."""
    recorder_id = self.info["id"]
    return hash(recorder_id)
|
||||
|
||||
def __eq__(self, o: object) -> bool:
    """Two MLflowRecorder objects are equal when their ids match; any other object is unequal."""
    if not isinstance(o, MLflowRecorder):
        return False
    return self.info["id"] == o.info["id"]
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
13
qlib/workflow/task/__init__.py
Normal file
13
qlib/workflow/task/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
Task related workflow is implemented in this folder
|
||||
|
||||
A typical task workflow
|
||||
|
||||
| Step | Description |
|
||||
|-----------------------+------------------------------------------------|
|
||||
| TaskGen | Generating tasks. |
|
||||
| TaskManager(optional) | Manage generated tasks |
|
||||
| run task | retrive tasks from TaskManager and run tasks. |
|
||||
"""
|
||||
219
qlib/workflow/task/collect.py
Normal file
219
qlib/workflow/task/collect.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
|
||||
|
||||
class Collector(Serializable):
    """The collector to collect different results"""

    pickle_backend = "dill"  # use dill to dump user method

    def __init__(self, process_list=None):
        """
        Init Collector.

        Args:
            process_list (list or Callable): the list of processors or the instance of a processor to process dict.
                None (the default) means no processor.
        """
        # A fresh list is created per instance: a literal `[]` default would be one
        # object shared by every collector built without an explicit `process_list`.
        if process_list is None:
            process_list = []
        elif not isinstance(process_list, list):
            process_list = [process_list]
        self.process_list = process_list

    def collect(self) -> dict:
        """
        Collect the results and return a dict like {key: things}

        Returns:
            dict: the dict after collecting.

            For example:

            {"prediction": pd.Series}

            {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}

            ......

        Raises:
            NotImplementedError: always; subclasses must override this method.
        """
        raise NotImplementedError("Please implement the `collect` method.")

    @staticmethod
    def process_collect(collected_dict, process_list=None, *args, **kwargs) -> dict:
        """
        Do a series of processing to the dict returned by collect and return a dict like {key: things}
        For example, you can group and ensemble.

        Args:
            collected_dict (dict): the dict return by `collect`
            process_list (list or Callable): the list of processors or the instance of a processor to process dict.
                The processor order is the same as the list order. None (the default) means no processor.
                For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())]
            *args, **kwargs: forwarded to every processor after the value.

        Returns:
            dict: the dict after processing.

        Raises:
            NotImplementedError: if a processor is not callable.
        """
        if process_list is None:
            process_list = []
        elif not isinstance(process_list, list):
            process_list = [process_list]
        result = {}
        for artifact, value in collected_dict.items():
            # Each processor receives the output of the previous one.
            for process in process_list:
                if not callable(process):
                    raise NotImplementedError(f"{type(process)} is not supported in `process_collect`.")
                value = process(value, *args, **kwargs)
            result[artifact] = value
        return result

    def __call__(self, *args, **kwargs) -> dict:
        """
        Do the workflow including ``collect`` and ``process_collect``

        Returns:
            dict: the dict after collecting and processing.
        """
        collected = self.collect()
        return self.process_collect(collected, self.process_list, *args, **kwargs)
|
||||
|
||||
|
||||
class MergeCollector(Collector):
    """
    A collector to collect the results of other Collectors

    For example:

        We have 2 collector, which named A and B.
        A can collect {"prediction": pd.Series} and B can collect {"IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}.
        Then after this class's collect, we can collect {"A_prediction": pd.Series, "B_IC": {"Xgboost": pd.Series, "LSTM": pd.Series}}

        ......

    """

    def __init__(self, collector_dict: Dict[str, Collector], process_list: List[Callable] = None, merge_func=None):
        """
        Init MergeCollector.

        Args:
            collector_dict (Dict[str,Collector]): the dict like {collector_key, Collector}
            process_list (List[Callable]): the list of processors or the instance of processor to process dict.
                None (the default) means no processor.
            merge_func (Callable): a method to generate outermost key. The given params are ``collector_key`` from collector_dict and ``key`` from every collector after collecting.
                None for using tuple to connect them, such as "ABC"+("a","b") -> ("ABC", ("a","b")).
        """
        # Hand the base class a fresh list instead of a shared literal default.
        super().__init__(process_list=[] if process_list is None else process_list)
        self.collector_dict = collector_dict
        self.merge_func = merge_func

    def collect(self) -> dict:
        """
        Collect all results of collector_dict and change the outermost key to a recombination key.

        Returns:
            dict: the dict after collecting.
        """
        collect_dict = {}
        for collector_key, collector in self.collector_dict.items():
            # Calling a collector runs its own collect + process workflow.
            tmp_dict = collector()
            for key, value in tmp_dict.items():
                if self.merge_func is not None:
                    merged_key = self.merge_func(collector_key, key)
                else:
                    merged_key = (collector_key, key)
                collect_dict[merged_key] = value
        return collect_dict
|
||||
|
||||
|
||||
class RecorderCollector(Collector):
    """Collect artifacts from the recorders of an experiment."""

    # Artifact key that collects the recorder itself instead of loading an artifact from it.
    ART_KEY_RAW = "__raw"

    def __init__(
        self,
        experiment,
        process_list=None,
        rec_key_func=None,
        rec_filter_func=None,
        artifacts_path=None,
        artifacts_key=None,
    ):
        """
        Init RecorderCollector.

        Args:
            experiment (Experiment or str): an instance of an Experiment or the name of an Experiment
            process_list (list or Callable): the list of processors or the instance of a processor to process dict.
                None (the default) means no processor.
            rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None.
            artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl"}.
            artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts.
        """
        # Mutable defaults are created per instance so that collectors never share them.
        super().__init__(process_list=[] if process_list is None else process_list)
        if isinstance(experiment, str):
            experiment = R.get_exp(experiment_name=experiment)
        self.experiment = experiment
        if artifacts_path is None:
            artifacts_path = {"pred": "pred.pkl"}
        self.artifacts_path = artifacts_path
        if rec_key_func is None:
            rec_key_func = lambda rec: rec.info["id"]
        if artifacts_key is None:
            artifacts_key = list(self.artifacts_path.keys())
        self.rec_key_func = rec_key_func
        self.artifacts_key = artifacts_key
        self.rec_filter_func = rec_filter_func

    def collect(self, artifacts_key=None, rec_filter_func=None, only_exist=True) -> dict:
        """
        Collect different artifacts based on recorder after filtering.

        Args:
            artifacts_key (str or List, optional): the artifacts key you want to get. If None, use the default.
            rec_filter_func (Callable, optional): filter the recorder by return True or False. If None, use the default.
            only_exist (bool, optional): if only collect the artifacts when a recorder really has.
                If True, the recorder with exception when loading will not be collected. But if False, it will raise the exception.

        Returns:
            dict: the dict after collected like {artifact: {rec_key: object}}
        """
        if artifacts_key is None:
            artifacts_key = self.artifacts_key
        if rec_filter_func is None:
            rec_filter_func = self.rec_filter_func

        if isinstance(artifacts_key, str):
            artifacts_key = [artifacts_key]

        collect_dict = {}
        # filter records
        recs = self.experiment.list_recorders()
        recs_flt = [rec for rec in recs.values() if rec_filter_func is None or rec_filter_func(rec)]

        for rec in recs_flt:
            rec_key = self.rec_key_func(rec)
            for key in artifacts_key:
                if self.ART_KEY_RAW == key:
                    artifact = rec
                else:
                    try:
                        artifact = rec.load_object(self.artifacts_path[key])
                    except Exception:
                        if only_exist:
                            # only collect existing artifact
                            continue
                        raise
                collect_dict.setdefault(key, {})[rec_key] = artifact

        return collect_dict

    def get_exp_name(self) -> str:
        """
        Get experiment name

        Returns:
            str: experiment name
        """
        return self.experiment.name
|
||||
231
qlib/workflow/task/gen.py
Normal file
231
qlib/workflow/task/gen.py
Normal file
@@ -0,0 +1,231 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
TaskGenerator module can generate many tasks based on TaskGen and some task templates.
|
||||
"""
|
||||
import abc
|
||||
import copy
|
||||
from typing import List, Union, Callable
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
def task_generator(tasks, generators) -> list:
    """
    Expand task templates by applying a sequence of TaskGen, one after another.

    Each generator is applied to every task produced so far, so the results multiply.
    For example, with 3 task templates a, b, c and 2 TaskGen A, B, where A generates 2 tasks
    from a template and B generates 3, task_generator([a, b, c], [A, B]) finally generates
    3*2*3 = 18 tasks.

    Parameters
    ----------
    tasks : List[dict] or dict
        a list of task templates or a single task
    generators : List[TaskGen] or TaskGen
        a list of TaskGen or a single TaskGen

    Returns
    -------
    list
        a list of tasks
    """
    current = [tasks] if isinstance(tasks, dict) else tasks
    gen_list = [generators] if isinstance(generators, TaskGen) else generators

    # The output of one generator is the input of the next one.
    for gen in gen_list:
        current = [new_task for template in current for new_task in gen.generate(template)]

    return current
|
||||
|
||||
|
||||
class TaskGen(metaclass=abc.ABCMeta):
    """
    Abstract base class of task generators.

    A generator turns one task template into a list of tasks, for instance:

    - a task template and rolling steps -> the rolling versions of the task
    - a task template and a list of losses -> a set of tasks with different losses
    """

    @abc.abstractmethod
    def generate(self, task: dict) -> List[dict]:
        """
        Generate different tasks based on a task template

        Parameters
        ----------
        task: dict
            a task template

        Returns
        -------
        typing.List[dict]:
            A list of tasks
        """

    def __call__(self, *args, **kwargs):
        """
        Syntactic sugar: calling the generator is the same as calling ``generate``.
        """
        return self.generate(*args, **kwargs)
|
||||
|
||||
|
||||
def handler_mod(task: dict, rolling_gen):
    """
    Extend the handler end time of a task so that it covers the end of the test segment.

    The task is modified in place. A task without the expected keys (for example a dataset
    without a handler) is left unchanged.

    Args:
        task (dict): a task template
        rolling_gen (RollingGen): an instance of RollingGen; its ``ta`` and ``test_key`` are used
    """
    try:
        cal_interval = rolling_gen.ta.cal_interval
        handler_kwargs = task["dataset"]["kwargs"]["handler"]["kwargs"]
        end_time = handler_kwargs["end_time"]
        test_end = task["dataset"]["kwargs"]["segments"][rolling_gen.test_key][1]
        # A negative interval is treated as "end_time is before the end of the test segment":
        # the handler end time is then moved to the test end so that more data can be loaded.
        if cal_interval(end_time, test_end) < 0:
            handler_kwargs["end_time"] = copy.deepcopy(test_end)
    except KeyError:
        # Maybe dataset do not have handler, then do nothing.
        pass
|
||||
|
||||
|
||||
class RollingGen(TaskGen):
    """Generate the rolling versions of a task template."""

    ROLL_EX = TimeAdjuster.SHIFT_EX  # fixed start date, expanding end date
    ROLL_SD = TimeAdjuster.SHIFT_SD  # fixed segments size, slide it from start date

    def __init__(self, step: int = 40, rtype: str = ROLL_EX, ds_extra_mod_func: Union[None, Callable] = handler_mod):
        """
        Generate tasks for rolling

        Parameters
        ----------
        step : int
            step to rolling
        rtype : str
            rolling type (expanding, sliding)
        ds_extra_mod_func: Callable
            A method like: handler_mod(task: dict, rg: RollingGen)
            Do some extra action after generating a task. For example, use ``handler_mod`` to modify the end time of the handler of a dataset.
        """
        self.step = step
        self.rtype = rtype
        self.ds_extra_mod_func = ds_extra_mod_func
        self.ta = TimeAdjuster(future=True)

        # names of the segments that receive special treatment in `generate`
        self.test_key = "test"
        self.train_key = "train"

    def generate(self, task: dict) -> List[dict]:
        """
        Converting the task into a rolling task.

        Parameters
        ----------
        task: dict
            A dict describing a task. Its segments are read from ``task["dataset"]["kwargs"]["segments"]``,
            for example:

            .. code-block:: python

                "segments": {
                    "train": ("2008-01-01", "2014-12-31"),
                    "valid": ("2015-01-01", "2016-12-20"),  # Please avoid leaking the future test data into validation
                    "test": ("2017-01-01", "2020-08-01"),
                },

        Returns
        -------
        List[dict]: a list of tasks
        """
        rolled_tasks = []

        prev_seg = None
        test_end = None
        while True:
            new_task = copy.deepcopy(task)

            # calculate segments
            if prev_seg is None:
                # First rolling
                # 1) prepare the end point
                segments: dict = copy.deepcopy(self.ta.align_seg(new_task["dataset"]["kwargs"]["segments"]))
                test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
                # 2) and init test segments
                test_start_idx = self.ta.align_idx(segments[self.test_key][0])
                segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
            else:
                segments = {}
                try:
                    for name, seg in prev_seg.items():
                        # decide how to shift
                        # expanding only for train data, the segments size of test data and valid data won't change
                        expanding = name == self.train_key and self.rtype == self.ROLL_EX
                        shift_type = self.ta.SHIFT_EX if expanding else self.ta.SHIFT_SD
                        # shift the segments data
                        segments[name] = self.ta.shift(seg, step=self.step, rtype=shift_type)
                    # Checked once all segments are shifted: stop when the test segment starts after the end point.
                    if segments[self.test_key][0] > test_end:
                        break
                except KeyError:
                    # We reach the end of tasks
                    # No more rolling
                    break

            # update segments of this task
            new_task["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments)
            prev_seg = segments
            if self.ds_extra_mod_func is not None:
                self.ds_extra_mod_func(new_task, self)
            rolled_tasks.append(new_task)
        return rolled_tasks
|
||||
493
qlib/workflow/task/manage.py
Normal file
493
qlib/workflow/task/manage.py
Normal file
@@ -0,0 +1,493 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
TaskManager can fetch unused tasks automatically and manage the lifecycle of a set of tasks with error handling.
|
||||
These features can run tasks concurrently and ensure every task will be used only once.
|
||||
Task Manager will store all tasks in `MongoDB <https://www.mongodb.com/>`_.
|
||||
Users **MUST** finish the configuration of `MongoDB <https://www.mongodb.com/>`_ before using this module.
|
||||
|
||||
A task in TaskManager consists of 3 parts
|
||||
- tasks description: the desc will define the task
|
||||
- tasks status: the status of the task
|
||||
- tasks result: A user can get the task with the task description and task result.
|
||||
"""
|
||||
import concurrent
|
||||
import pickle
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, List
|
||||
|
||||
import fire
|
||||
import pymongo
|
||||
from bson.binary import Binary
|
||||
from bson.objectid import ObjectId
|
||||
from pymongo.errors import InvalidDocument
|
||||
from qlib import auto_init, get_module_logger
|
||||
from tqdm.cli import tqdm
|
||||
|
||||
from .utils import get_mongodb
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""
|
||||
TaskManager
|
||||
|
||||
Here is what a task looks like when it is created by TaskManager
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
{
|
||||
'def': pickle serialized task definition. using pickle will make it easier
|
||||
'filter': json-like data. This is for filtering the tasks.
|
||||
'status': 'waiting' | 'running' | 'done'
|
||||
'res': pickle serialized task result,
|
||||
}
|
||||
|
||||
The tasks manager assumes that you will only update the tasks you fetched.
|
||||
MongoDB's fetch-one-and-update operation makes updating the data safe.
|
||||
|
||||
.. note::
|
||||
|
||||
Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded
|
||||
|
||||
There are four statuses:
|
||||
|
||||
STATUS_WAITING: waiting for training
|
||||
|
||||
STATUS_RUNNING: training
|
||||
|
||||
STATUS_PART_DONE: finished some step and waiting for next step
|
||||
|
||||
STATUS_DONE: all work done
|
||||
"""
|
||||
|
||||
STATUS_WAITING = "waiting"
|
||||
STATUS_RUNNING = "running"
|
||||
STATUS_DONE = "done"
|
||||
STATUS_PART_DONE = "part_done"
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
    """
    Init Task Manager, remember to make the statement of MongoDB url and database name firstly.

    Parameters
    ----------
    task_pool: str
        the name of Collection in MongoDB. When None, no ``task_pool`` attribute is set.
    """
    self.mdb = get_mongodb()
    if task_pool is not None:
        # the collection is looked up as an attribute of the database object
        self.task_pool = getattr(self.mdb, task_pool)
    self.logger = get_module_logger(type(self).__name__)
|
||||
|
||||
def list(self) -> list:
    """
    List all the collections (task pools) of the database.

    Returns:
        list
    """
    collection_names = self.mdb.list_collection_names()
    return collection_names
|
||||
|
||||
def _encode_task(self, task):
    """
    Pickle, in place, every field of `task` whose key starts with one of ``ENCODE_FIELDS_PREFIX``.

    Returns the same `task` dict, with the matching values wrapped as ``Binary``.
    """
    # The keys are snapshotted so that assigning values while iterating is safe.
    for key in tuple(task.keys()):
        for prefix in self.ENCODE_FIELDS_PREFIX:
            if key.startswith(prefix):
                task[key] = Binary(pickle.dumps(task[key]))
    return task
|
||||
|
||||
def _decode_task(self, task):
    """
    Unpickle, in place, every field of `task` whose key starts with one of ``ENCODE_FIELDS_PREFIX``.

    Returns the same `task` dict.
    """
    # NOTE(review): pickle.loads executes whatever the stored bytes describe; this is only
    # safe when the task collection is written by trusted code.
    for key in tuple(task.keys()):
        for prefix in self.ENCODE_FIELDS_PREFIX:
            if key.startswith(prefix):
                task[key] = pickle.loads(task[key])
    return task
|
||||
|
||||
def _dict_to_str(self, flt):
    """Return a new dict with the keys of `flt` and every value converted by ``str``."""
    converted = {}
    for key, value in flt.items():
        converted[key] = str(value)
    return converted
|
||||
|
||||
def replace_task(self, task, new_task):
    """
    Use a new task to replace a old one

    Args:
        task: old task; only its ``_id`` is used, to locate the document to replace
        new_task: new task; it is encoded in place before being written

    Raises:
        InvalidDocument: if the new task cannot be stored even after its filter is converted to strings.
    """
    new_task = self._encode_task(new_task)
    query = {"_id": ObjectId(task["_id"])}
    try:
        self.task_pool.replace_one(query, new_task)
    except InvalidDocument:
        if "filter" not in new_task:
            # Nothing this method knows how to repair: surface the original error.
            raise
        # The document being written is `new_task`, so its filter is the one that has to be
        # converted before retrying; converting the old task's filter would leave the
        # retried document unchanged.
        new_task["filter"] = self._dict_to_str(new_task["filter"])
        self.task_pool.replace_one(query, new_task)
|
||||
|
||||
def insert_task(self, task):
    """
    Insert a task.

    If the document is rejected as invalid, the values of ``task["filter"]`` are converted
    to strings (in place) and the insertion is tried once more.

    Args:
        task: the task waiting for insert

    Returns:
        pymongo.results.InsertOneResult
    """
    pool = self.task_pool
    try:
        return pool.insert_one(task)
    except InvalidDocument:
        task["filter"] = self._dict_to_str(task["filter"])
        return pool.insert_one(task)
|
||||
|
||||
def insert_task_def(self, task_def):
    """
    Wrap a task definition into a task document and insert it into task_pool.

    The document stores ``task_def`` under both ``"def"`` and ``"filter"`` and starts
    with the status ``self.STATUS_WAITING``; it is encoded before insertion.

    Parameters
    ----------
    task_def: dict
        the task definition

    Returns
    -------
    pymongo.results.InsertOneResult
    """
    document = {
        "def": task_def,
        "filter": task_def,  # FIXME: catch the raised error
        "status": self.STATUS_WAITING,
    }
    encoded = self._encode_task(document)
    return self.insert_task(encoded)
|
||||
|
||||
def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]:
    """
    Make sure every task definition in task_def_l exists in the task_pool.

    A definition is looked up by its ``filter`` field. Definitions that are not found
    are treated as new and (unless ``dry_run``) inserted; for the others the ``_id``
    of the stored document is collected.

    Parameters
    ----------
    task_def_l: list
        a list of task
    dry_run: bool
        when True nothing is inserted and an empty list is returned
    print_nt: bool
        if print new task

    Returns
    -------
    List[str]
        a list of the _id of task_def_l (empty when ``dry_run`` is True)
    """
    fresh_defs = []
    ids = []
    for task_def in task_def_l:
        try:
            existing = self.task_pool.find_one({"filter": task_def})
        except InvalidDocument:
            existing = self.task_pool.find_one({"filter": self._dict_to_str(task_def)})

        if existing is not None:
            ids.append(self._decode_task(existing)["_id"])
            continue

        fresh_defs.append(task_def)
        if dry_run:
            ids.append(None)
        else:
            ids.append(self.insert_task_def(task_def).inserted_id)

    self.logger.info(f"Total Tasks: {len(task_def_l)}, New Tasks: {len(fresh_defs)}")

    if print_nt:  # print new task
        for task_def in fresh_defs:
            print(task_def)

    return [] if dry_run else ids
|
||||
|
||||
def fetch_task(self, query={}, status=STATUS_WAITING) -> dict:
    """
    Claim one task that matches the query and mark it as running.

    The lookup and the status change are issued as a single ``find_one_and_update``
    call; the status is also set on the returned dict.

    Args:
        query (dict, optional): query dict. Defaults to {}. It is copied, not modified.
        status (str, optional): the status a task must have to be fetched. Defaults to STATUS_WAITING.

    Returns:
        dict: a task(document in collection) after decoding, or None when nothing matches
    """
    criteria = query.copy()
    if "_id" in criteria:
        criteria["_id"] = ObjectId(criteria["_id"])
    criteria.update({"status": status})

    # null will be at the top after sorting when using ASCENDING, so the larger the number higher, the higher the priority
    claimed = self.task_pool.find_one_and_update(
        criteria,
        {"$set": {"status": self.STATUS_RUNNING}},
        sort=[("priority", pymongo.DESCENDING)],
    )
    if claimed is None:
        return None

    claimed["status"] = self.STATUS_RUNNING
    return self._decode_task(claimed)
|
||||
|
||||
@contextmanager
def safe_fetch_task(self, query={}, status=STATUS_WAITING):
    """
    Fetch task from task_pool using query with contextmanager.

    If the body of the ``with`` block raises (including a ``KeyboardInterrupt``), the
    fetched task is handed back to the pool with the status it was fetched in, and the
    error is re-raised.

    Parameters
    ----------
    query: dict
        the dict of query
    status: str
        the status a task must have to be fetched; also the status it is returned to on error

    Returns
    -------
    dict: a task(document in collection) after decoding (None when nothing matches)
    """
    task = self.fetch_task(query=query, status=status)
    try:
        yield task
    except (Exception, KeyboardInterrupt):
        # KeyboardInterrupt is not a subclass of Exception; without listing it a Ctrl-C
        # would leave the task marked as running.
        if task is not None:
            self.logger.info("Returning task before raising error")
            # Give the task back in its original status, not unconditionally the
            # default of return_task, so tasks fetched in another status keep it.
            self.return_task(task, status=status)
            self.logger.info("Task returned")
        raise
|
||||
|
||||
def task_fetcher_iter(self, query={}):
    """
    Yield tasks fetched through ``safe_fetch_task`` one at a time until none is left.

    Each task is yielded while its ``safe_fetch_task`` context is still open.

    Args:
        query (dict, optional): query dict passed to ``safe_fetch_task``. Defaults to {}.
    """
    while True:
        with self.safe_fetch_task(query=query) as fetched:
            if fetched is None:
                return
            yield fetched
|
||||
|
||||
def query(self, query={}, decode=True):
    """
    Query task in collection.
    This function may raise exception `pymongo.errors.CursorNotFound: cursor id not found` if it takes too long to iterate the generator

    Parameters
    ----------
    query: dict
        the dict of query; it is copied, not modified
    decode: bool
        when True (default) every document is passed through ``_decode_task``;
        when False the raw documents are yielded as stored

    Returns
    -------
    generator of dict: the matching tasks (documents in collection)
    """
    query = query.copy()
    if "_id" in query:
        query["_id"] = ObjectId(query["_id"])
    for t in self.task_pool.find(query):
        # `decode` used to be ignored, so callers passing decode=False still paid for
        # unpickling every document.
        yield self._decode_task(t) if decode else t
|
||||
|
||||
def re_query(self, _id):
    """
    Look a task up again by its _id.

    Args:
        _id (str): _id of a document

    Returns:
        dict: a task(document in collection) after decoding
    """
    document = self.task_pool.find_one({"_id": ObjectId(_id)})
    decoded = self._decode_task(document)
    return decoded
|
||||
|
||||
def commit_task_res(self, task, res, status=STATUS_DONE):
    """
    Store a result on the task document and update its status.

    The result is pickled and written to the ``res`` field of the document whose
    ``_id`` equals ``task["_id"]``.

    Args:
        task (dict): the task to update; only ``task["_id"]`` is read
        res (object): the result you want to save
        status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE.
    """
    # A workaround to use the class attribute.
    if status is None:
        status = TaskManager.STATUS_DONE
    selector = {"_id": task["_id"]}
    changes = {"status": status, "res": Binary(pickle.dumps(res))}
    self.task_pool.update_one(selector, {"$set": changes})
|
||||
|
||||
def return_task(self, task, status=STATUS_WAITING):
    """
    Put a task back into the given status. Mostly used in error handling.

    Args:
        task (dict): the task to update; only ``task["_id"]`` is read
        status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING.
    """
    if status is None:
        status = TaskManager.STATUS_WAITING
    selector = {"_id": task["_id"]}
    self.task_pool.update_one(selector, {"$set": {"status": status}})
|
||||
|
||||
def remove(self, query={}):
    """
    Delete every task that matches the query.

    Parameters
    ----------
    query: dict
        the dict of query; it is copied, and an ``_id`` entry is converted to ``ObjectId``
    """
    criteria = query.copy()
    if "_id" in criteria:
        criteria["_id"] = ObjectId(criteria["_id"])
    self.task_pool.delete_many(criteria)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
    """
    Count the matching tasks per status.

    Args:
        query (dict, optional): the query dict. Defaults to {}.

    Returns:
        dict: maps each status found to the number of tasks in that status
    """
    criteria = query.copy()
    if "_id" in criteria:
        criteria["_id"] = ObjectId(criteria["_id"])
    counts = {}
    for document in self.query(query=criteria, decode=False):
        state = document["status"]
        counts[state] = counts.get(state, 0) + 1
    return counts
|
||||
|
||||
def reset_waiting(self, query={}):
    """
    Put tasks back into the waiting status. Can be used when some running task exit unexpected.

    Unless the query names a status itself, only tasks in ``self.STATUS_RUNNING`` are reset.

    Args:
        query (dict, optional): the query dict. Defaults to {}.
    """
    criteria = query.copy()
    # default query
    criteria.setdefault("status", self.STATUS_RUNNING)
    return self.reset_status(query=criteria, status=self.STATUS_WAITING)
|
||||
|
||||
def reset_status(self, query, status):
    """
    Set ``status`` on every task that matches ``query`` and print the update result.

    Args:
        query (dict): the query dict; it is copied, and an ``_id`` entry is converted to ``ObjectId``
        status (str): the status to assign
    """
    criteria = query.copy()
    if "_id" in criteria:
        criteria["_id"] = ObjectId(criteria["_id"])
    outcome = self.task_pool.update_many(criteria, {"$set": {"status": status}})
    print(outcome)
|
||||
|
||||
def prioritize(self, task, priority: int):
    """
    Write a priority onto a task document.

    Parameters
    ----------
    task : dict
        The task query from the database; only ``task["_id"]`` is read
    priority : int
        the target priority
    """
    selector = {"_id": task["_id"]}
    self.task_pool.update_one(selector, {"$set": {"priority": priority}})
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
    """
    Block until no matching task is waiting or running, showing a progress bar.

    The task statistics are polled every 10 seconds; the first poll happens after the
    first sleep.

    Args:
        query (dict, optional): the query dict passed to ``task_stat``. Defaults to {}.
    """
    stat = self.task_stat(query)
    total = self._get_total(stat)
    remaining = self._get_undone_n(stat)
    with tqdm(total=total, initial=total - remaining) as pbar:
        while True:
            time.sleep(10)
            current = self._get_undone_n(self.task_stat(query))
            # advance the bar by the number of tasks finished since the last poll
            pbar.update(remaining - current)
            remaining = current
            if current == 0:
                break
|
||||
|
||||
def __str__(self):
|
||||
return f"TaskManager({self.task_pool})"
|
||||
|
||||
|
||||
def run_task(
    task_func: Callable,
    task_pool: str,
    query: dict = {},
    force_release: bool = False,
    before_status: str = TaskManager.STATUS_WAITING,
    after_status: str = TaskManager.STATUS_DONE,
    **kwargs,
):
    """
    Keep fetching tasks in `before_status` from task_pool and run them with task_func until none is left.

    After running this method, here are 4 situations (before_status -> after_status):

        STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param

        STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param

        STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param

        STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param

    Parameters
    ----------
    task_func : Callable
        def (task_def, **kwargs) -> <res which will be committed>
            the function to run the task
    task_pool : str
        the name of the task pool (Collection in MongoDB)
    query: dict
        will use this dict to query task_pool when fetching task
    force_release : bool
        when True, task_func is run in a one-worker ``ProcessPoolExecutor`` that is shut down after each task
    before_status : str:
        the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE.
    after_status : str:
        the status committed together with the result; STATUS_DONE or STATUS_PART_DONE in the situations listed above.
    kwargs
        the params for `task_func`

    Returns
    -------
    bool: True when at least one task has been run

    Raises
    ------
    ValueError
        when a task is fetched with a before_status other than STATUS_WAITING or STATUS_PART_DONE
    """
    manager = TaskManager(task_pool)
    executed_any = False

    while True:
        with manager.safe_fetch_task(status=before_status, query=query) as task:
            if task is None:
                break
            get_module_logger("run_task").info(task["def"])

            if before_status == TaskManager.STATUS_WAITING:
                # when fetching `WAITING` task, use task["def"] to train
                param = task["def"]
            elif before_status == TaskManager.STATUS_PART_DONE:
                # when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"]
                param = task["res"]
            else:
                raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!")

            if not force_release:
                res = task_func(param, **kwargs)
            else:
                with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
                    res = executor.submit(task_func, param, **kwargs).result()

            manager.commit_task_res(task, res, status=after_status)
            executed_any = True

    return executed_any
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # This is for using it in cmd
    # E.g. : `python -m qlib.workflow.task.manage list`
    # auto_init() runs first, then fire exposes the TaskManager class on the command line.
    auto_init()
    fire.Fire(TaskManager)
|
||||
258
qlib/workflow/task/utils.py
Normal file
258
qlib/workflow/task/utils.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
Some tools for task management.
|
||||
"""
|
||||
|
||||
import bisect
|
||||
import pandas as pd
|
||||
from qlib.data import D
|
||||
from qlib.workflow import R
|
||||
from qlib.config import C
|
||||
from qlib.log import get_module_logger
|
||||
from pymongo import MongoClient
|
||||
from pymongo.database import Database
|
||||
from typing import Union
|
||||
|
||||
|
||||
def get_mongodb() -> Database:
    """
    Open the MongoDB database described by ``C["mongo"]``.

    The address and the name of the database must be declared before calling this
    function, for example:

        Using qlib.init():

            mongo_conf = {
                "task_url": task_url,  # your MongoDB url
                "task_db_name": task_db_name,  # database name
            }
            qlib.init(..., mongo=mongo_conf)

        After qlib.init():

            C["mongo"] = {
                "task_url" : "mongodb://localhost:27017/",
                "task_db_name" : "rolling_db"
            }

    Returns:
        Database: the Database instance

    Raises:
        KeyError: re-raised (after logging an error) when ``C["mongo"]`` is not configured
    """
    try:
        mongo_cfg = C["mongo"]
    except KeyError:
        get_module_logger("task").error("Please configure `C['mongo']` before using TaskManager")
        raise

    mongo_client = MongoClient(mongo_cfg["task_url"])
    database = mongo_client.get_database(name=mongo_cfg["task_db_name"])
    return database
|
||||
|
||||
|
||||
def list_recorders(experiment, rec_filter_func=None):
    """
    Collect the recorders of an experiment that pass an optional filter.

    Args:
        experiment (str or Experiment): the name of an Experiment or an instance
        rec_filter_func (Callable, optional): return True to retain the given recorder. Defaults to None (keep all).

    Returns:
        dict: a dict {rid: recorder} after filtering.
    """
    exp = R.get_exp(experiment_name=experiment) if isinstance(experiment, str) else experiment
    return {
        rid: rec
        for rid, rec in exp.list_recorders().items()
        if rec_filter_func is None or rec_filter_func(rec)
    }
|
||||
|
||||
|
||||
class TimeAdjuster:
    """
    Find appropriate date and adjust date.

    Every lookup is made against ``self.cals``, the calendar returned by ``D.calendar``.
    """

    def __init__(self, future=True, end_time=None):
        self._future = future
        self.cals = D.calendar(future=future, end_time=end_time)

    def set_end_time(self, end_time=None):
        """
        Set end time. None for use calendar's end time.

        Args:
            end_time
        """
        self.cals = D.calendar(future=self._future, end_time=end_time)

    def get(self, idx: int):
        """
        Get datetime by index.

        Parameters
        ----------
        idx : int
            index of the calendar

        Returns
        -------
        the calendar entry at ``idx``, or None when ``idx`` is at or beyond the end of the calendar
        """
        if idx >= len(self.cals):
            return None
        return self.cals[idx]

    def max(self) -> pd.Timestamp:
        """
        Return the max calendar datetime
        """
        return max(self.cals)

    def align_idx(self, time_point, tp_type="start") -> int:
        """
        Align the index of time_point in the calendar.

        Parameters
        ----------
        time_point
            anything accepted by ``pd.Timestamp``
        tp_type : str
            ``"start"``: index of the first calendar entry that is >= time_point;
            ``"end"``: index of the last calendar entry that is <= time_point

        Returns
        -------
        index : int

        Raises
        ------
        NotImplementedError
            when ``tp_type`` is neither ``"start"`` nor ``"end"``
        """
        time_point = pd.Timestamp(time_point)
        if tp_type == "start":
            idx = bisect.bisect_left(self.cals, time_point)
        elif tp_type == "end":
            idx = bisect.bisect_right(self.cals, time_point) - 1
        else:
            raise NotImplementedError("This type of input is not supported")
        return idx

    def cal_interval(self, time_point_A, time_point_B) -> int:
        """
        Calculate the trading day interval (time_point_A - time_point_B)

        Args:
            time_point_A : time_point_A
            time_point_B : time_point_B (is the past of time_point_A)

        Returns:
            int: the interval between A and B
        """
        return self.align_idx(time_point_A) - self.align_idx(time_point_B)

    def align_time(self, time_point, tp_type="start") -> pd.Timestamp:
        """
        Align time_point to trade date of calendar

        Args:
            time_point
                Time point
            tp_type : str
                time point type (`"start"`, `"end"`)

        Returns:
            pd.Timestamp
        """
        return self.cals[self.align_idx(time_point, tp_type=tp_type)]

    def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]:
        """
        Align the given date to the trade date

        for example:

            .. code-block:: python

                input: {'train': ('2008-01-01', '2014-12-31'), 'valid': ('2015-01-01', '2016-12-31'), 'test': ('2017-01-01', '2020-08-01')}

                output: {'train': (Timestamp('2008-01-02 00:00:00'), Timestamp('2014-12-31 00:00:00')),
                        'valid': (Timestamp('2015-01-05 00:00:00'), Timestamp('2016-12-30 00:00:00')),
                        'test': (Timestamp('2017-01-03 00:00:00'), Timestamp('2020-07-31 00:00:00'))}

        Parameters
        ----------
        segment
            a (start, end) tuple or list, or a dict whose values are such segments

        Returns
        -------
        Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date.

        Raises
        ------
        NotImplementedError
            when ``segment`` is not a dict, tuple or list
        """
        if isinstance(segment, dict):
            return {k: self.align_seg(seg) for k, seg in segment.items()}
        elif isinstance(segment, (tuple, list)):
            return self.align_time(segment[0], tp_type="start"), self.align_time(segment[1], tp_type="end")
        else:
            raise NotImplementedError("This type of input is not supported")

    def truncate(self, segment: tuple, test_start, days: int) -> tuple:
        """
        Truncate the segment based on the test_start date

        Parameters
        ----------
        segment : tuple
            time segment
        test_start
        days : int
            The trading days to be truncated
            the data in this segment may need 'days' data

        Returns
        ---------
        tuple: new segment

        Raises
        ------
        NotImplementedError
            when ``segment`` is not a tuple
        """
        test_idx = self.align_idx(test_start)
        if isinstance(segment, tuple):
            new_seg = []
            for time_point in segment:
                # no point of the segment may be closer than `days` entries to test_start
                tp_idx = min(self.align_idx(time_point), test_idx - days)
                assert tp_idx > 0
                new_seg.append(self.get(tp_idx))
            return tuple(new_seg)
        else:
            raise NotImplementedError("This type of input is not supported")

    SHIFT_SD = "sliding"
    SHIFT_EX = "expanding"

    def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple:
        """
        Shift the datetime of segment

        Parameters
        ----------
        seg :
            datetime segment
        step : int
            rolling step
        rtype : str
            rolling type ("sliding" or "expanding"); sliding moves both ends, expanding moves only the end

        Returns
        --------
        tuple: new segment; the end is None when the shifted end index is beyond the calendar

        Raises
        ------
        KeyError:
            when the shifted start index is out of self.cals
        NotImplementedError:
            when ``seg`` is not a tuple or ``rtype`` is unknown
        """
        if isinstance(seg, tuple):
            start_idx, end_idx = self.align_idx(seg[0], tp_type="start"), self.align_idx(seg[1], tp_type="end")
            if rtype == self.SHIFT_SD:
                start_idx += step
                end_idx += step
            elif rtype == self.SHIFT_EX:
                end_idx += step
            else:
                raise NotImplementedError("This type of input is not supported")
            # `>=`, not `>`: index len(self.cals) is already past the last entry. With `>`
            # a start exactly one past the end slipped through and came back as None.
            if start_idx >= len(self.cals):
                raise KeyError("The segment is out of valid calendar")
            return self.get(start_idx), self.get(end_idx)
        else:
            raise NotImplementedError("This type of input is not supported")
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal, atexit
|
||||
import sys, traceback, signal, atexit, logging
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
# function to handle the experiment when unusual program ending occurs
|
||||
|
||||
@@ -15,7 +15,11 @@
|
||||
### Download CN Data
|
||||
|
||||
```bash
|
||||
# daily data
|
||||
python get_data.py qlib_data --target_dir ~/.qlib/qlib_data/cn_data --region cn
|
||||
|
||||
# 1min data (Optional for running non-high-frequency strategies)
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min --region cn --interval 1min
|
||||
```
|
||||
|
||||
### Download US Data
|
||||
|
||||
24
scripts/data_collector/contrib/README.md
Normal file
24
scripts/data_collector/contrib/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Get future trading days
|
||||
|
||||
> `D.calendar(future=True)` will be used
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collect Data
|
||||
|
||||
```bash
|
||||
# collect future trading dates and write them to <qlib_dir>/calendars/<freq>_future.txt
|
||||
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- qlib_dir: qlib data directory
|
||||
- freq: value from [`day`, `1min`], default `day`
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
# get data from baostock
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
    """Load ``<qlib_dir>/calendars/day.txt`` as a header-less frame; return an empty frame if the file is missing."""
    day_file = qlib_dir / "calendars" / "day.txt"
    if day_file.exists():
        return pd.read_csv(day_file, header=None)
    return pd.DataFrame()
|
||||
|
||||
|
||||
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
    """Write ``date_list``, one entry per line, to ``<qlib_dir>/calendars/<freq>_future.txt`` and log the path."""
    target = qlib_dir.joinpath("calendars", f"{freq}_future.txt")
    calendar_path = str(target)

    np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
    logger.info(f"write future calendars success: {calendar_path}")
|
||||
|
||||
|
||||
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
    """Convert a list of daily dates into the calendar lines for the requested frequency.

    Parameters
    ----------
    date_list: List[str]
        daily trading dates
    freq: str
        "day" returns ``date_list`` unchanged; "1min" expands each day with
        ``generate_minutes_calendar_from_daily`` and formats every entry as
        "%Y-%m-%d %H:%M:%S"

    Raises
    ------
    ValueError
        when ``freq`` is neither "day" nor "1min"
    """
    # A leftover debug `print(freq)` used to sit here; it polluted stdout on every call.
    if freq == "day":
        return date_list
    elif freq == "1min":
        date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
        return [pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S") for x in date_list]
    else:
        raise ValueError(f"Unsupported freq: {freq}")
|
||||
|
||||
|
||||
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
    """get future calendar

    Queries baostock for the trading dates from January 1st of the last year found in
    the existing daily calendar (or of the current year when there is none) up to
    December 31st of the current year, and writes them with ``write_calendar_to_qlib``.

    Parameters
    ----------
    qlib_dir: str or Path
        qlib data directory
    freq: str
        value from ["day", "1min"], by default day

    Raises
    ------
    FileNotFoundError
        when ``qlib_dir`` does not exist
    """
    qlib_dir = Path(qlib_dir).expanduser().resolve()
    if not qlib_dir.exists():
        raise FileNotFoundError(str(qlib_dir))

    lg = bs.login()
    if lg.error_code != "0":
        logger.error(f"login error: {lg.error_msg}")
        return
    try:
        # read daily calendar
        daily_calendar = read_calendar_from_qlib(qlib_dir)
        end_year = pd.Timestamp.now().year
        if daily_calendar.empty:
            start_year = pd.Timestamp.now().year
        else:
            start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
        # NOTE(review): start_date is passed as a pd.Timestamp while end_date is a str;
        # kept as in the original - confirm which form baostock expects.
        rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31")
        data_list = []
        # `and` (not bitwise `&`) so rs.next() is not called once an error is reported
        while rs.error_code == "0" and rs.next():
            _row_data = rs.get_row_data()
            if int(_row_data[1]) == 1:
                data_list.append(_row_data[0])
        if rs.error_code != "0":
            # do not overwrite the calendar file with a partial/empty result
            logger.error(f"query trade dates error: {rs.error_msg}")
            return
        data_list = sorted(data_list)
        date_list = generate_qlib_calendar(data_list, freq=freq)
        write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
    finally:
        # always release the baostock session, also when something above raises
        bs.logout()
    logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Expose future_calendar_collector as a command line interface through fire.
    fire.Fire(future_calendar_collector)
|
||||
5
scripts/data_collector/contrib/requirements.txt
Normal file
5
scripts/data_collector/contrib/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
baostock
|
||||
fire
|
||||
numpy
|
||||
pandas
|
||||
loguru
|
||||
@@ -114,6 +114,8 @@ class IndexBase:
|
||||
$ python collector.py save_new_companies --index_name CSI300 --qlib_dir ~/.qlib/qlib_data/cn_data
|
||||
"""
|
||||
df = self.get_new_companies()
|
||||
if df is None or df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
df = df.drop_duplicates([self.SYMBOL_FIELD_NAME])
|
||||
df.loc[:, self.INSTRUMENTS_COLUMNS].to_csv(
|
||||
self.instruments_dir.joinpath(f"{self.index_name.lower()}_only_new.txt"), sep="\t", index=False, header=None
|
||||
@@ -184,7 +186,10 @@ class IndexBase:
|
||||
logger.info(f"start parse {self.index_name.lower()} companies.....")
|
||||
instruments_columns = [self.SYMBOL_FIELD_NAME, self.START_DATE_FIELD, self.END_DATE_FIELD]
|
||||
changers_df = self.get_changes()
|
||||
new_df = self.get_new_companies().copy()
|
||||
new_df = self.get_new_companies()
|
||||
if new_df is None or new_df.empty:
|
||||
raise ValueError(f"get new companies error: {self.index_name}")
|
||||
new_df = new_df.copy()
|
||||
logger.info("parse history companies by changes......")
|
||||
for _row in tqdm(changers_df.sort_values(self.DATE_FIELD_NAME, ascending=False).itertuples(index=False)):
|
||||
if _row.type == self.ADD:
|
||||
|
||||
@@ -35,7 +35,7 @@ WIKI_INDEX_NAME_MAP = {
|
||||
class WIKIIndex(IndexBase):
|
||||
# NOTE: The US stock code contains "PRN", and the directory cannot be created on Windows system, use the "_" prefix
|
||||
# https://superuser.com/questions/613313/why-cant-we-make-con-prn-null-folder-in-windows
|
||||
INST_PREFIX = "_"
|
||||
INST_PREFIX = ""
|
||||
|
||||
def __init__(self, index_name: str, qlib_dir: [str, Path] = None, request_retry: int = 5, retry_sleep: int = 3):
|
||||
super(WIKIIndex, self).__init__(
|
||||
@@ -123,7 +123,7 @@ class NASDAQ100Index(WIKIIndex):
|
||||
MAX_WORKERS = 16
|
||||
|
||||
def filter_df(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
if not (set(df.columns) - {"Company", "Ticker"}):
|
||||
if len(df) >= 100 and "Ticker" in df.columns:
|
||||
return df.loc[:, ["Ticker"]].copy()
|
||||
|
||||
@property
|
||||
|
||||
@@ -10,7 +10,9 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
@@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
|
||||
return res
|
||||
|
||||
|
||||
def generate_minutes_calendar_from_daily(
    calendars: Iterable,
    freq: str = "1min",
    am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
    pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
) -> pd.Index:
    """generate minutes calendar

    Parameters
    ----------
    calendars: Iterable
        daily calendar
    freq: str
        by default 1min
    am_range: Tuple[str, str]
        AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
    pm_range: Tuple[str, str]
        PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")

    Returns
    -------
    pd.Index
        the sorted, de-duplicated timestamps of both ranges for every day;
        an empty DatetimeIndex when ``calendars`` is empty
    """
    daily_format: str = "%Y-%m-%d"
    res = []
    for _day in calendars:
        # format the day once; both sessions share it
        _day_str = pd.Timestamp(_day).strftime(daily_format)
        for _range in [am_range, pm_range]:
            res.append(pd.date_range(f"{_day_str} {_range[0]}", f"{_day_str} {_range[1]}", freq=freq))

    if not res:
        # np.hstack raises ValueError on an empty list
        return pd.DatetimeIndex([])
    return pd.Index(sorted(set(np.hstack(res))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
|
||||
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
|
||||
res = []
|
||||
daily_format = self.DAILY_FORMAT
|
||||
am_range = self.AM_RANGE
|
||||
pm_range = self.PM_RANGE
|
||||
for _day in calendars:
|
||||
for _range in [am_range, pm_range]:
|
||||
res.append(
|
||||
pd.date_range(
|
||||
f"{_day.strftime(daily_format)} {_range[0]}",
|
||||
f"{_day.strftime(daily_format)} {_range[1]}",
|
||||
freq="1min",
|
||||
)
|
||||
)
|
||||
|
||||
return pd.Index(sorted(set(np.hstack(res))))
|
||||
return generate_minutes_calendar_from_daily(
|
||||
calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
|
||||
)
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user