mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-29 00:51:19 +08:00
Compare commits
40 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 7540b1257b | | |
| | 57f7ed9914 | | |
| | 9e3d0249f7 | | |
| | 2ac964c470 | | |
| | 07f0d4f599 | | |
| | ea4fb33ff2 | | |
| | ed0c238787 | | |
| | 80af395b3c | | |
| | 4dc66932d5 | | |
| | ec8969a3ae | | |
| | 528f74af09 | | |
| | d482726f28 | | |
| | cfc3e886ed | | |
| | 60d45ad770 | | |
| | 0e8b94a552 | | |
| | 4bf127eba5 | | |
| | c149c8616c | | |
| | 3274e16c95 | | |
| | d496cf7476 | | |
| | 357ee74b6f | | |
| | 5da5cf5175 | | |
| | 6a946761cf | | |
| | 76b7b5f24b | | |
| | d7d19feb4e | | |
| | bba6972a55 | | |
| | 18af288692 | | |
| | ba056850cb | | |
| | aed5b8ebc0 | | |
| | 79355666a9 | | |
| | 144e1e2459 | | |
| | 635632e4ed | | |
| | c5834476e2 | | |
| | 01afd06e18 | | |
| | d533219738 | | |
| | 5b5c99fe75 | | |
| | da48f42f3f | | |
| | f979dcf5e8 | | |
| | 97aa16a078 | | |
| | 094be9be86 | | |
| | d9b9386032 | | |
3
.github/workflows/python-publish.yml
vendored
@@ -12,7 +12,8 @@ jobs:
|
||||
runs-on: ${{ matrix.os }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [windows-latest, macos-latest, macos-11]
|
||||
os: [windows-latest, macos-11]
|
||||
# FIXME: macos-latest will raise error now.
|
||||
# Python 3.6 is not supported because annotations are not supported: https://stackoverflow.com/a/52890129
|
||||
python-version: [3.7, 3.8]
|
||||
|
||||
|
||||
32
.github/workflows/test.yml
vendored
@@ -33,7 +33,37 @@ jobs:
|
||||
- name: Install Qlib with pip
|
||||
run: |
|
||||
pip install numpy==1.19.5 ruamel.yaml
|
||||
pip install pyqlib --ignore-installed
|
||||
pip install pyqlib --ignore-installed
|
||||
|
||||
# Check Qlib with pylint
|
||||
# TODO: These problems we will solve in the future. Important among them are: W0221, W0223, W0237, E1102
|
||||
# C0103: invalid-name
|
||||
# C0209: consider-using-f-string
|
||||
# R0402: consider-using-from-import
|
||||
# R1705: no-else-return
|
||||
# R1710: inconsistent-return-statements
|
||||
# R1725: super-with-arguments
|
||||
# R1735: use-dict-literal
|
||||
# W0102: dangerous-default-value
|
||||
# W0212: protected-access
|
||||
# W0221: arguments-differ
|
||||
# W0223: abstract-method
|
||||
# W0231: super-init-not-called
|
||||
# W0237: arguments-renamed
|
||||
# W0612: unused-variable
|
||||
# W0621: redefined-outer-name
|
||||
# W0622: redefined-builtin
|
||||
# FIXME: specify exception type
|
||||
# W0703: broad-except
|
||||
# W1309: f-string-without-interpolation
|
||||
# E1102: not-callable
|
||||
# E1136: unsubscriptable-object
|
||||
# References for parameters: https://github.com/PyCQA/pylint/issues/4577#issuecomment-1000245962
|
||||
- name: Check Qlib with pylint
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install pylint
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
- name: Test data downloads
|
||||
run: |
|
||||
|
||||
5
.pylintrc
Normal file
@@ -0,0 +1,5 @@
|
||||
[TYPECHECK]
|
||||
# https://stackoverflow.com/a/53572939
|
||||
# List of members which are set dynamically and missed by Pylint inference
|
||||
# system, and so shouldn't trigger E1101 when accessed.
|
||||
generated-members=numpy.*, torch.*
|
||||
65
README.md
@@ -45,27 +45,52 @@ With Qlib, users can easily try ideas to create better Quant investment strategi
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**Plans**](#plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Auto Quant Research Workflow](#auto-quant-research-workflow)
|
||||
- [Building Customized Quant Research Workflow by Code](#building-customized-quant-research-workflow-by-code)
|
||||
- [Main Challenges & Solutions in Quant Research](#main-challenges--solutions-in-quant-research)
|
||||
- [Forecasting: Finding Valuable Signals/Patterns](#forecasting-finding-valuable-signalspatterns)
|
||||
- [**Quant Model (Paper) Zoo**](#quant-model-paper-zoo)
|
||||
- [Run a Single Model](#run-a-single-model)
|
||||
- [Run Multiple Models](#run-multiple-models)
|
||||
- [Adapting to Market Dynamics](#adapting-to-market-dynamics)
|
||||
- [**Quant Dataset Zoo**](#quant-dataset-zoo)
|
||||
- [More About Qlib](#more-about-qlib)
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
<table>
|
||||
<tbody>
|
||||
<tr>
|
||||
<th>Frameworks, Tutorial, Data & DevOps</th>
|
||||
<th>Main Challenges & Solutions in Quant Research</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<li><a href="#plans"><strong>Plans</strong></a></li>
|
||||
<li><a href="#framework-of-qlib">Framework of Qlib</a></li>
|
||||
<li><a href="#quick-start">Quick Start</a></li>
|
||||
<ul dir="auto">
|
||||
<li type="circle"><a href="#installation">Installation</a> </li>
|
||||
<li type="circle"><a href="#data-preparation">Data Preparation</a></li>
|
||||
<li type="circle"><a href="#auto-quant-research-workflow">Auto Quant Research Workflow</a></li>
|
||||
<li type="circle"><a href="#building-customized-quant-research-workflow-by-code">Building Customized Quant Research Workflow by Code</a></li></ul>
|
||||
<li><a href="#quant-dataset-zoo"><strong>Quant Dataset Zoo</strong></a></li>
|
||||
<li><a href="#more-about-qlib">More About Qlib</a></li>
|
||||
<li><a href="#offline-mode-and-online-mode">Offline Mode and Online Mode</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#performance-of-qlib-data-server">Performance of Qlib Data Server</a></li></ul>
|
||||
<li><a href="#related-reports">Related Reports</a></li>
|
||||
<li><a href="#contact-us">Contact Us</a></li>
|
||||
<li><a href="#contributing">Contributing</a></li>
|
||||
</td>
|
||||
<td valign="baseline">
|
||||
<li><a href="#main-challenges--solutions-in-quant-research">Main Challenges & Solutions in Quant Research</a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#forecasting-finding-valuable-signalspatterns">Forecasting: Finding Valuable Signals/Patterns</a>
|
||||
<ul>
|
||||
<li type="disc"><a href="#quant-model-paper-zoo"><strong>Quant Model (Paper) Zoo</strong></a>
|
||||
<ul>
|
||||
<li type="circle"><a href="#run-a-single-model">Run a Single Model</a></li>
|
||||
<li type="circle"><a href="#run-multiple-models">Run Multiple Models</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</ul>
|
||||
</li>
|
||||
<li type="circle"><a href="#adapting-to-market-dynamics">Adapting to Market Dynamics</a></li>
|
||||
</ul>
|
||||
</li>
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
# Plans
|
||||
New features under development (ordered by estimated release time).
|
||||
|
||||
@@ -52,7 +52,8 @@ Also, ``Qlib`` provides a high-frequency dataset. Users can run a high-frequency
|
||||
Qlib Format Dataset
|
||||
--------------------
|
||||
``Qlib`` provides an off-the-shelf dataset in `.bin` format. Users can use the script ``scripts/get_data.py`` to download the China-Stock dataset as follows.
|
||||
The price and volume data look different from the actual dealing prices because they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). You may also find that adjusted prices differ between data sources, because different sources adjust prices in different ways. Qlib normalizes the price on the first trading day of each stock to 1 when adjusting them.
|
||||
The price and volume data look different from the actual dealing prices because they are **adjusted** (`adjusted price <https://www.investopedia.com/terms/a/adjusted_closing_price.asp>`_). You may also find that adjusted prices differ between data sources, because different sources adjust prices in different ways. Qlib normalizes the price on the first trading day of each stock to 1 when adjusting them.
|
||||
Users can leverage `$factor` to get the original trading price (e.g. `$close / $factor` to get the original close price).
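
As a minimal sketch of the relation described above, assuming the adjusted price equals the original price multiplied by `$factor`, the snippet below recovers original close prices with plain pandas. The numbers and the column names `close`, `factor`, and `raw_close` are made up for illustration and do not come from a real Qlib dataset.

```python
import pandas as pd

# Hypothetical adjusted data for one stock (illustrative values only).
adjusted = pd.DataFrame(
    {"close": [1.00, 1.02, 0.51], "factor": [0.01, 0.01, 0.005]},
    index=pd.to_datetime(["2020-01-02", "2020-01-03", "2020-01-06"]),
)

# Dividing the adjusted price by the factor gives the original trading price,
# mirroring the expression `$close / $factor`.
adjusted["raw_close"] = adjusted["close"] / adjusted["factor"]
print(adjusted)
```

In this made-up series the factor halves on the last day, so the adjusted close drops while the recovered original close stays level.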
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
|
||||
@@ -28,4 +28,11 @@ The frequency of trading algorithm, decision content and execution environment c
|
||||
Example
|
||||
===========================
|
||||
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
An example of nested decision execution framework for high-frequency can be found `here <https://github.com/microsoft/qlib/blob/main/examples/nested_decision_execution/workflow.py>`_.
|
||||
|
||||
|
||||
Besides the above examples, here is some other related work on high-frequency trading in Qlib.
|
||||
|
||||
- `Prediction with high-frequency data <https://github.com/microsoft/qlib/tree/main/examples/highfreq#benchmarks-performance-predicting-the-price-trend-in-high-frequency-data>`_
|
||||
- `Examples <https://github.com/microsoft/qlib/blob/main/examples/orderbook_data/>`_ to extract features from high-frequency data without fixed frequency.
|
||||
- `A paper <https://github.com/microsoft/qlib/tree/high-freq-execution#high-frequency-execution>`_ for high-frequency trading.
|
||||
|
||||
@@ -126,7 +126,9 @@ A prediction sample is shown as follows.
|
||||
|
||||
Normally, the prediction score is the output of the model. But some models are learned from a label with a different scale, so the scale of the prediction score may differ from what you expect (e.g. the return of instruments).
|
||||
|
||||
Qlib didn't add a step to scale the prediction score to a unified scale. Because not every trading strategy cares about the scale(e.g. TopkDropoutStrategy only cares about the order). So the strategy is responsible for rescaling the prediction score(e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
Qlib doesn't add a step to scale the prediction score to a unified scale, for the following reasons.
|
||||
- Not every trading strategy cares about the scale (e.g. TopkDropoutStrategy only cares about the order), so the strategy is responsible for rescaling the prediction score (e.g. some portfolio-optimization-based strategies may require a meaningful scale).
|
||||
- The model has the flexibility to define the target, loss, and data processing, so we don't think there is a silver bullet for rescaling based solely on the model's outputs. If you want to scale the score back to meaningful values (e.g. stock returns), an intuitive solution is to fit a regression model on the model's recent outputs and your recent target values.
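
The regression idea in the last bullet could be sketched roughly as follows. This is only an illustration under assumptions: the arrays `recent_scores` and `recent_returns` are invented, the helper `rescale` is hypothetical, and a simple least-squares line from NumPy stands in for whatever regression model suits your data.

```python
import numpy as np

# Hypothetical recent model scores and the returns realized for the same samples.
recent_scores = np.array([-1.5, -0.5, 0.0, 0.5, 1.5])
recent_returns = np.array([-0.04, -0.02, -0.01, 0.0, 0.02])

# Fit returns ~ slope * score + intercept with ordinary least squares.
slope, intercept = np.polyfit(recent_scores, recent_returns, deg=1)


def rescale(scores):
    """Map raw prediction scores onto the scale of the regression target."""
    return slope * np.asarray(scores) + intercept


print(rescale([1.0, -1.0]))
```

A linear map preserves the ranking of the scores whenever the fitted slope is positive, so order-based strategies are unaffected while scale-sensitive ones get values in return units.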
|
||||
|
||||
Running backtest
|
||||
-----------------
|
||||
@@ -192,6 +194,14 @@ Running backtest
|
||||
qlib.init(provider_uri=<qlib data dir>)
|
||||
|
||||
CSI300_BENCH = "SH000300"
|
||||
# Benchmark is for calculating the excess return of your strategy.
|
||||
# Its data format will be like **ONE normal instrument**.
|
||||
# For example, you can query its data with the code below
|
||||
# `D.features(["SH000300"], ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
# It is different from the argument `market`, which indicates a universe of stocks (e.g. **A SET** of stocks like csi300)
|
||||
# For example, you can query all data from a stock market with the code below.
|
||||
# ` D.features(D.instruments(market='csi300'), ["$close"], start_time='2010-01-01', end_time='2017-12-31', freq='day')`
|
||||
|
||||
FREQ = "day"
|
||||
STRATEGY_CONFIG = {
|
||||
"topk": 50,
|
||||
|
||||
12
docs/conf.py
@@ -54,9 +54,9 @@ master_doc = "index"
|
||||
|
||||
|
||||
# General information about the project.
|
||||
project = u"QLib"
|
||||
copyright = u"Microsoft"
|
||||
author = u"Microsoft"
|
||||
project = "QLib"
|
||||
copyright = "Microsoft"
|
||||
author = "Microsoft"
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
@@ -174,7 +174,7 @@ latex_elements = {
|
||||
# (source start file, target name, title,
|
||||
# author, documentclass [howto, manual, or own class]).
|
||||
latex_documents = [
|
||||
(master_doc, "qlib.tex", u"QLib Documentation", u"Microsoft", "manual"),
|
||||
(master_doc, "qlib.tex", "QLib Documentation", "Microsoft", "manual"),
|
||||
]
|
||||
|
||||
|
||||
@@ -182,7 +182,7 @@ latex_documents = [
|
||||
|
||||
# One entry per manual page. List of tuples
|
||||
# (source start file, name, description, authors, manual section).
|
||||
man_pages = [(master_doc, "qlib", u"QLib Documentation", [author], 1)]
|
||||
man_pages = [(master_doc, "qlib", "QLib Documentation", [author], 1)]
|
||||
|
||||
|
||||
# -- Options for Texinfo output -------------------------------------------
|
||||
@@ -194,7 +194,7 @@ texinfo_documents = [
|
||||
(
|
||||
master_doc,
|
||||
"QLib",
|
||||
u"QLib Documentation",
|
||||
"QLib Documentation",
|
||||
author,
|
||||
"QLib",
|
||||
"One line description of project.",
|
||||
|
||||
@@ -14,9 +14,19 @@ Continuous Integration (CI) tools help you stick to the quality standards by run
|
||||
|
||||
When you submit a pull request, you can check whether your code passes the CI tests in the "check" section at the bottom of the web page.
|
||||
|
||||
A common error is the mixed use of spaces and tabs. You can fix the problem by entering the following commands in the command line.
|
||||
1. Qlib checks the code format with black. The PR will fail if your code does not conform to Qlib's standard (e.g. a common error is the mixed use of spaces and tabs).
|
||||
You can fix the problem by entering the following commands in the command line.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
pip install black
|
||||
python -m black . -l 120
|
||||
|
||||
|
||||
2. Qlib checks your code style with pylint. The checking command is implemented in the `GitHub Actions workflow <https://github.com/microsoft/qlib/blob/0e8b94a552f1c457cfa6cd2c1bb3b87ebb3fb279/.github/workflows/test.yml#L66>`_.
|
||||
Sometimes pylint's restrictions are not reasonable. You can ignore specific errors like this:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
|
||||
@@ -120,6 +120,32 @@ For more details about features, please refer `Feature API <../component/data.ht
|
||||
|
||||
.. note:: When calling `D.features()` at the client, use parameter `disk_cache=0` to skip dataset cache, use `disk_cache=1` to generate and use dataset cache. In addition, when calling at the server, users can use `disk_cache=2` to update the dataset cache.
|
||||
|
||||
|
||||
When you are building complicated expressions, implementing all the expressions in a single string may not be easy.
|
||||
For example, the following expression is long and complicated:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data import D
|
||||
>> data = D.features(["sh600519"], ["(($high / $close) + ($open / $close)) * (($high / $close) + ($open / $close)) / (($high / $close) + ($open / $close))"], start_time="20200101")
|
||||
|
||||
|
||||
But a string is not the only way to implement an expression. You can also build expressions in code.
|
||||
Here is an example that does the same thing as the example above.
|
||||
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
>> from qlib.data.ops import *
|
||||
>> f1 = Feature("high") / Feature("close")
|
||||
>> f2 = Feature("open") / Feature("close")
|
||||
>> f3 = f1 + f2
|
||||
>> f4 = f3 * f3 / f3
|
||||
|
||||
>> data = D.features(["sh600519"], [f4], start_time="20200101")
|
||||
>> data.head()
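
One way to picture how such composition can work is operator overloading. The toy classes `Expr`, `Field`, and `BinOp` below are hypothetical and are not Qlib's implementation; they only sketch how `/`, `+`, and `*` can build an expression tree that renders back to a string.

```python
class Expr:
    """Base class: arithmetic operators return new tree nodes instead of numbers."""

    def __add__(self, other):
        return BinOp("+", self, other)

    def __mul__(self, other):
        return BinOp("*", self, other)

    def __truediv__(self, other):
        return BinOp("/", self, other)


class Field(Expr):
    """Leaf node that refers to a raw field such as `high` or `close`."""

    def __init__(self, name):
        self.name = name

    def __str__(self):
        return f"${self.name}"


class BinOp(Expr):
    """Inner node that combines two sub-expressions with an operator."""

    def __init__(self, op, left, right):
        self.op, self.left, self.right = op, left, right

    def __str__(self):
        # Always parenthesize so the rendered string keeps the tree's structure.
        return f"({self.left} {self.op} {self.right})"


f1 = Field("high") / Field("close")
f2 = Field("open") / Field("close")
f3 = f1 + f2
f4 = f3 * f3 / f3
expr_str = str(f4)
print(expr_str)
```

Because every node is parenthesized when rendered, the string form cannot lose grouping the way a hand-written expression string can.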
|
||||
|
||||
|
||||
API
|
||||
====================
|
||||
To know more about how to use the Data, go to API Reference: `Data API <../reference/api.html#data>`_
|
||||
|
||||
@@ -37,7 +37,8 @@ Initialize Qlib before calling other APIs: run following code in python.
|
||||
Parameters
|
||||
-------------------
|
||||
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters. The following are several important parameters of `qlib.init`:
|
||||
Besides `provider_uri` and `region`, `qlib.init` has other parameters.
|
||||
The following are several important parameters of `qlib.init` (`Qlib` has many configuration options, and only some of the parameters are listed here. More detailed settings can be found `here <https://github.com/microsoft/qlib/blob/main/qlib/config.py>`_):
|
||||
|
||||
- `provider_uri`
|
||||
Type: str. The URI of the Qlib data. For example, it could be the location where the data loaded by ``get_data.py`` are stored.
|
||||
@@ -48,7 +49,7 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
- ``qlib.constant.REG_CN``: China stock market.
|
||||
|
||||
Different modes will result in different trading limitations and costs.
|
||||
The region is just `shortcuts for defining a batch of configurations <https://github.com/microsoft/qlib/blob/main/qlib/config.py#L239>`_. Users can set the key configurations manually if the existing region setting can't meet their requirements.
|
||||
The region is just a `shortcut for defining a batch of configurations <https://github.com/microsoft/qlib/blob/528f74af099bf6156e9480bcd2bb28e453231212/qlib/config.py#L249>`_, which includes the minimal trading order unit (``trade_unit``), the trading limitation (``limit_threshold``), etc. It is not required; users can set the key configurations manually if the existing region settings don't meet their requirements.
|
||||
- `redis_host`
|
||||
Type: str, optional parameter(default: "127.0.0.1"), host of `redis`
|
||||
The lock and cache mechanism relies on redis.
|
||||
@@ -88,3 +89,9 @@ Besides `provider_uri` and `region`, `qlib.init` has other parameters. The follo
|
||||
"task_url": "mongodb://localhost:27017/", # your mongo url
|
||||
"task_db_name": "rolling_db", # the database name of Task Management
|
||||
})
|
||||
|
||||
- `logging_level`
|
||||
The logging level for the system.
|
||||
|
||||
- `kernels`
|
||||
The number of processes used when calculating features in Qlib's expression engine. It is very helpful to set it to 1 when you are debugging an exception raised during expression calculation.
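
A debugging setup along these lines might look like the sketch below. The dictionary name `debug_init_kwargs` is hypothetical, the data path is only an example location, and passing a standard `logging` level as `logging_level` is an assumption; the actual `qlib.init` call is left as a comment so the snippet runs without Qlib installed.

```python
import logging

# Hypothetical settings for a debugging session; the data path is only an example.
debug_init_kwargs = {
    "provider_uri": "~/.qlib/qlib_data/cn_data",
    # Compute features in a single process so tracebacks are easier to follow.
    "kernels": 1,
    # Assumed to accept a standard logging level.
    "logging_level": logging.DEBUG,
}

# With Qlib installed and the data prepared, these would be passed as:
#   import qlib
#   qlib.init(**debug_init_kwargs)
print(debug_init_kwargs)
```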
|
||||
|
||||
@@ -63,8 +63,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 157
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -73,6 +71,8 @@ task:
|
||||
batch_size: 8192
|
||||
GPU: 0
|
||||
weight_decay: 0.0002
|
||||
pt_model_kwargs:
|
||||
input_dim: 157
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -51,8 +51,6 @@ task:
|
||||
module_path: qlib.contrib.model.pytorch_nn
|
||||
kwargs:
|
||||
loss: mse
|
||||
input_dim: 360
|
||||
output_dim: 1
|
||||
lr: 0.002
|
||||
lr_decay: 0.96
|
||||
lr_decay_steps: 100
|
||||
@@ -60,6 +58,8 @@ task:
|
||||
max_steps: 8000
|
||||
batch_size: 4096
|
||||
GPU: 0
|
||||
pt_model_kwargs:
|
||||
input_dim: 360
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
|
||||
@@ -4,6 +4,7 @@ This page lists a batch of methods designed for alpha seeking. Each method tries
|
||||
The alpha is evaluated in two ways.
|
||||
1. The correlation between the alpha and future return.
|
||||
1. Constructing portfolio based on the alpha and evaluating the final total return.
|
||||
- The explanation of metrics can be found [here](https://qlib.readthedocs.io/en/latest/component/report.html#id4)
|
||||
|
||||
Here are the results of each benchmark model running on Qlib's `Alpha360` and `Alpha158` datasets with China's A-share stock & CSI300 data respectively. The values of each metric are the mean and std calculated over 20 runs with different random seeds.
|
||||
|
||||
@@ -16,6 +17,8 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
> NOTE:
|
||||
> The backtest in versions from 0.8.0 onward is quite different from that in previous versions. Please check the changelog for the differences.
|
||||
|
||||
> NOTE:
|
||||
> We have very limited resources to implement and finetune the models. We have done our best to compare these models fairly, but some models may have greater potential than the table below suggests. Contributions that explore their potential are highly welcome.
|
||||
|
||||
## Alpha158 dataset
|
||||
|
||||
@@ -66,3 +69,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
- The base model of TCTS is GRU.
|
||||
- About the datasets
|
||||
- Alpha158 is a tabular dataset. There are fewer spatial relationships between different features. Each feature is carefully designed by humans (a.k.a. feature engineering).
|
||||
- Alpha360 contains raw price and volume data without much feature engineering. There are strong spatial relationships between the features in the time dimension.
|
||||
- The metrics can be categorized into two groups
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
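
To make the signal-based group concrete, here is a rough sketch under commonly used definitions, which are assumptions rather than something stated on this page: IC as the per-day Pearson correlation between score and label, Rank IC as the same correlation computed on ranks, and ICIR as the mean of the daily ICs divided by their standard deviation. The data and the helper `signal_metrics` are hypothetical.

```python
import pandas as pd

# Hypothetical predictions and labels for two days and three instruments.
df = pd.DataFrame(
    {
        "date": ["d1"] * 3 + ["d2"] * 3,
        "score": [1.0, 2.0, 3.0, 1.0, 2.0, 3.0],
        "label": [0.01, 0.02, 0.03, 0.03, 0.01, 0.02],
    }
)


def signal_metrics(data):
    """Compute daily IC and Rank IC, then summarize them."""
    ic, rank_ic = {}, {}
    for day, group in data.groupby("date"):
        # IC: Pearson correlation between score and label within one day.
        ic[day] = group["score"].corr(group["label"])
        # Rank IC: the same correlation computed on ranks (Spearman correlation).
        rank_ic[day] = group["score"].rank().corr(group["label"].rank())
    ic, rank_ic = pd.Series(ic), pd.Series(rank_ic)
    return {
        "IC": ic.mean(),
        "ICIR": ic.mean() / ic.std(),
        "Rank IC": rank_ic.mean(),
        "Rank ICIR": rank_ic.mean() / rank_ic.std(),
    }


metrics = signal_metrics(df)
print(metrics)
```

Ranking before correlating avoids a SciPy dependency and gives the Spearman correlation when there are no ties, or when ties receive average ranks as they do by default here.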
|
||||
|
||||
@@ -130,7 +130,7 @@ class TRAModel(Model):
|
||||
|
||||
if prob is not None:
|
||||
P = sinkhorn(-L, epsilon=0.01) # sample assignment matrix
|
||||
lamb = self.lamb * (self.rho ** self.global_step)
|
||||
lamb = self.lamb * (self.rho**self.global_step)
|
||||
reg = prob.log().mul(P).sum(dim=-1).mean()
|
||||
loss = loss - lamb * reg
|
||||
|
||||
@@ -547,7 +547,7 @@ def evaluate(pred):
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MSE = (diff**2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label)
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
scikit_learn==0.23.2
|
||||
torch==1.7.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
torch==1.2.0
|
||||
@@ -1,3 +1,3 @@
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
pandas==1.1.2
|
||||
xgboost==1.2.1
|
||||
@@ -4,16 +4,16 @@ This is the implementation of `DDG-DA` based on `Meta Controller` component prov
|
||||
Please refer to the paper for more details: *DDG-DA: Data Distribution Generation for Predictable Concept Drift Adaptation* [[arXiv](https://arxiv.org/abs/2201.04038)]
|
||||
|
||||
|
||||
## Background
|
||||
# Background
|
||||
In many real-world scenarios, we often deal with streaming data that is sequentially collected over time. Due to the non-stationary nature of the environment, the streaming data distribution may change in unpredictable ways, which is known as concept drift. To handle concept drift, previous methods first detect when/where the concept drift happens and then adapt models to fit the distribution of the latest data. However, there are still many cases that some underlying factors of environment evolution are predictable, making it possible to model the future concept drift trend of the streaming data, while such cases are not fully explored in previous work.
|
||||
|
||||
Therefore, we propose a novel method `DDG-DA`, that can effectively forecast the evolution of data distribution and improve the performance of models. Specifically, we first train a predictor to estimate the future data distribution, then leverage it to generate training samples, and finally train models on the generated data.
|
||||
|
||||
## Dataset
|
||||
# Dataset
|
||||
The data in the paper are private. So we conduct experiments on Qlib's public dataset.
|
||||
Though the dataset is different, the conclusion remains the same. By applying `DDG-DA`, users can see rising trends at the test phase both in the proxy models' ICs and the performances of the forecasting models.
|
||||
|
||||
## Run the Code
|
||||
# Run the Code
|
||||
Users can try `DDG-DA` by running the following command:
|
||||
```bash
|
||||
python workflow.py run_all
|
||||
@@ -24,7 +24,10 @@ The default forecasting models are `Linear`. Users can choose other forecasting
|
||||
python workflow.py --forecast_model="gbdt" run_all
|
||||
```
|
||||
|
||||
|
||||
## Results
|
||||
|
||||
# Results
|
||||
The results of related methods in Qlib's public dataset can be found [here](../)
|
||||
|
||||
# Requirements
|
||||
Here are the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
|
||||
* Memory: 45G
|
||||
* Disk: 4G
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
optuna==2.7.0
|
||||
optuna-dashboard==0.4.1
|
||||
|
||||
@@ -21,7 +21,7 @@ class TestClass(unittest.TestCase):
|
||||
provider_uri = "~/.qlib/qlib_data/yahoo_cn_1min"
|
||||
qlib.init(
|
||||
provider_uri=provider_uri,
|
||||
mem_cache_size_limit=1024 ** 3 * 2,
|
||||
mem_cache_size_limit=1024**3 * 2,
|
||||
mem_cache_type="sizeof",
|
||||
kernels=1,
|
||||
expression_provider={"class": "LocalExpressionProvider", "kwargs": {"time2idx": False}},
|
||||
|
||||
@@ -24,6 +24,7 @@ We use China stock market data for our example.
|
||||
unzip -d ~/.qlib/qlib_data/cn_data csi300_weight.zip
|
||||
rm -f csi300_weight.zip
|
||||
```
|
||||
NOTE: We could not find any free public resource that provides the benchmark weights. To run the example, we created this weight data manually.
|
||||
|
||||
2. Prepare risk model data:
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.2"
|
||||
__version__ = "0.8.4"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
@@ -30,8 +30,8 @@ def init(default_conf="client", **kwargs):
|
||||
When using the recorder, skip_if_reg can be set to True to avoid loss of the recorder.
|
||||
|
||||
"""
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
from .config import C # pylint: disable=C0415
|
||||
from .data.cache import H # pylint: disable=C0415
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
@@ -85,7 +85,7 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
mount_command = "sudo mount.nfs %s %s" % (provider_uri, mount_path)
|
||||
# If the provider uri looks like this 172.23.233.89//data/csdesign'
|
||||
# It will be a nfs path. The client provider will be used
|
||||
if not auto_mount:
|
||||
if not auto_mount: # pylint: disable=R1702
|
||||
if not Path(mount_path).exists():
|
||||
raise FileNotFoundError(
|
||||
f"Invalid mount path: {mount_path}! Please mount manually: {mount_command} or Set init parameter `auto_mount=True`"
|
||||
@@ -139,8 +139,10 @@ def _mount_nfs_uri(provider_uri, mount_path, auto_mount: bool = False):
|
||||
if not _is_mount:
|
||||
try:
|
||||
Path(mount_path).mkdir(parents=True, exist_ok=True)
|
||||
except Exception:
|
||||
raise OSError(f"Failed to create directory {mount_path}, please create {mount_path} manually!")
|
||||
except Exception as e:
|
||||
raise OSError(
|
||||
f"Failed to create directory {mount_path}, please create {mount_path} manually!"
|
||||
) from e
|
||||
|
||||
# check nfs-common
|
||||
command_res = os.popen("dpkg -l | grep nfs-common")
|
||||
|
||||
@@ -171,8 +171,8 @@ def get_strategy_executor(
|
||||
# NOTE:
|
||||
# - for avoiding recursive import
|
||||
# - typing annotations is not reliable
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from ..strategy.base import BaseStrategy # pylint: disable=C0415
|
||||
from .executor import BaseExecutor # pylint: disable=C0415
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
|
||||
@@ -2,11 +2,11 @@
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
import copy
|
||||
from typing import Dict, List, Tuple, TYPE_CHECKING
|
||||
from typing import Dict, List, Tuple
|
||||
from qlib.utils import init_instance_by_config
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .position import BasePosition
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
|
||||
@@ -7,19 +7,18 @@ from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from typing import ClassVar, Optional, Union, List, Tuple
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
@@ -418,7 +417,7 @@ class BaseTradeDecision:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
|
||||
|
||||
# clip index
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
|
||||
@@ -3,13 +3,13 @@
|
||||
from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .account import Account
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
from typing import List, Tuple, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -18,7 +18,7 @@ from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, PandasQuote, NumpyQuote
|
||||
from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
|
||||
|
||||
class Exchange:
|
||||
|
||||
@@ -1,22 +1,18 @@
|
||||
from abc import abstractclassmethod, abstractmethod
|
||||
from abc import abstractmethod
|
||||
import copy
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
from .decision import EmptyTradeDecision, Order, BaseTradeDecision
|
||||
from .decision import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
|
||||
@@ -193,7 +189,8 @@ class BaseExecutor:
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@abstractclassmethod
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
@@ -453,7 +450,6 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_exe_res :
|
||||
the execution result of inner task
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_all_executors(self):
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
@@ -362,7 +360,9 @@ class Position(BasePosition):
|
||||
# check if to delete
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
"only have {} {}, require {}".format(
|
||||
self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount
|
||||
)
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
@@ -538,7 +538,7 @@ class InfPosition(BasePosition):
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
|
||||
@@ -10,11 +10,8 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from .decision import IdxTradeRange
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from .high_performance_ds import BaseOrderIndicator, PandasOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..data import D
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
@@ -388,13 +388,11 @@ class QlibConfig(Config):
|
||||
default_conf : str
|
||||
the default config template chosen by user: "server", "client"
|
||||
"""
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache
|
||||
from .utils import set_log_with_config, get_module_logger, can_use_cache # pylint: disable=C0415
|
||||
|
||||
self.reset()
|
||||
|
||||
_logging_config = self.logging_config
|
||||
if "logging_config" in kwargs:
|
||||
_logging_config = kwargs["logging_config"]
|
||||
_logging_config = kwargs.get("logging_config", self.logging_config)
|
||||
|
||||
# set global config
|
||||
if _logging_config:
|
||||
@@ -433,11 +431,11 @@ class QlibConfig(Config):
|
||||
)
|
||||
|
||||
def register(self):
|
||||
from .utils import init_instance_by_config
|
||||
from .data.ops import register_all_ops
|
||||
from .data.data import register_all_wrappers
|
||||
from .workflow import R, QlibRecorder
|
||||
from .workflow.utils import experiment_exit_handler
|
||||
from .utils import init_instance_by_config # pylint: disable=C0415
|
||||
from .data.ops import register_all_ops # pylint: disable=C0415
|
||||
from .data.data import register_all_wrappers # pylint: disable=C0415
|
||||
from .workflow import R, QlibRecorder # pylint: disable=C0415
|
||||
from .workflow.utils import experiment_exit_handler # pylint: disable=C0415
|
||||
|
||||
register_all_ops(self)
|
||||
register_all_wrappers(self)
|
||||
@@ -454,7 +452,7 @@ class QlibConfig(Config):
|
||||
self._registered = True
|
||||
|
||||
def reset_qlib_version(self):
|
||||
import qlib
|
||||
import qlib # pylint: disable=C0415
|
||||
|
||||
reset_version = self.get("qlib_reset_version", None)
|
||||
if reset_version is not None:
|
||||
|
||||
@@ -7,8 +7,7 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
from qlib.data.dataset import DatasetH, DataHandler
|
||||
from qlib.data.dataset import DatasetH
|
||||
|
||||
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
@@ -16,7 +15,7 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
def _to_tensor(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return torch.tensor(x, dtype=torch.float, device=device)
|
||||
return torch.tensor(x, dtype=torch.float, device=device) # pylint: disable=E1101
|
||||
return x
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,7 @@ from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.processor import Processor
|
||||
from ...utils import get_callable_kwargs
|
||||
from ...data.dataset import processor as processor_module
|
||||
from ...log import TimeInspector
|
||||
from inspect import getfullargspec
|
||||
import copy
|
||||
|
||||
|
||||
def check_transform_proc(proc_l, fit_start_time, fit_end_time):
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
from ...log import TimeInspector
|
||||
from ...utils.serial import Serializable
|
||||
from ...data.dataset.processor import Processor, get_group_columns
|
||||
|
||||
|
||||
@@ -62,10 +59,10 @@ class ConfigSectionProcessor(Processor):
|
||||
|
||||
# Features
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLEN|^KLOW|^KUP")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.25).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
cols = df_focus.columns[df_focus.columns.str.contains("^KLOW2|^KUP2")]
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x ** 0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
df_focus[cols] = df_focus[cols].apply(lambda x: x**0.5).groupby(level="datetime").apply(_feature_norm)
|
||||
|
||||
_cols = [
|
||||
"KMID",
|
||||
|
||||
@@ -4,8 +4,10 @@ Here is a batch of evaluation functions.
|
||||
The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
from qlib import get_module_logger
|
||||
from qlib.utils.paral import complex_parallel, DelayedDict
|
||||
from joblib import Parallel, delayed
|
||||
|
||||
|
||||
def calc_long_short_prec(
|
||||
@@ -61,32 +63,6 @@ def calc_long_short_prec(
|
||||
return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_long_short_return(
|
||||
pred: pd.Series,
|
||||
label: pd.Series,
|
||||
@@ -127,3 +103,105 @@ def calc_long_short_return(
|
||||
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
|
||||
r_avg = group.label.mean()
|
||||
return (r_long - r_short) / 2, r_avg
|
||||
|
||||
|
||||
def pred_autocorr(pred: pd.Series, lag=1, inst_col="instrument", date_col="datetime"):
|
||||
"""pred_autocorr.
|
||||
|
||||
Limitation:
|
||||
- If the datetimes are not densely sequential, the correlation will be calculated based on adjacent dates (some users may expect NaN).
|
||||
|
||||
:param pred: pd.Series with following format
|
||||
instrument datetime
|
||||
SH600000 2016-01-04 -0.000403
|
||||
2016-01-05 -0.000753
|
||||
2016-01-06 -0.021801
|
||||
2016-01-07 -0.065230
|
||||
2016-01-08 -0.062465
|
||||
:type pred: pd.Series
|
||||
:param lag:
|
||||
"""
|
||||
if isinstance(pred, pd.DataFrame):
|
||||
pred = pred.iloc[:, 0]
|
||||
get_module_logger("pred_autocorr").warning("Only the first column in {pred.columns} of `pred` is kept")
|
||||
pred_ustk = pred.sort_index().unstack(inst_col)
|
||||
corr_s = {}
|
||||
for (idx, cur), (_, prev) in zip(pred_ustk.iterrows(), pred_ustk.shift(lag).iterrows()):
|
||||
corr_s[idx] = cur.corr(prev)
|
||||
corr_s = pd.Series(corr_s).sort_index()
|
||||
return corr_s
|
||||
|
||||
|
||||
def pred_autocorr_all(pred_dict, n_jobs=-1, **kwargs):
|
||||
"""
|
||||
calculate auto correlation for pred_dict
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict : dict
|
||||
A dict like {<method_name>: <prediction>}
|
||||
kwargs :
|
||||
all these arguments will be passed into pred_autocorr
|
||||
"""
|
||||
ac_dict = {}
|
||||
for k, pred in pred_dict.items():
|
||||
ac_dict[k] = delayed(pred_autocorr)(pred, **kwargs)
|
||||
return complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), ac_dict)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_all_ic(pred_dict_all, label, date_col="datetime", dropna=False, n_jobs=-1):
|
||||
"""calc_all_ic.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred_dict_all :
|
||||
A dict like {<method_name>: <prediction>}
|
||||
label:
|
||||
A pd.Series of label values
|
||||
|
||||
Returns
|
||||
-------
|
||||
{'Q2+IND_z': {'ic': <ic series like>
|
||||
2016-01-04 -0.057407
|
||||
...
|
||||
2020-05-28 0.183470
|
||||
2020-05-29 0.171393
|
||||
'ric': <rank ic series like>
|
||||
2016-01-04 -0.040888
|
||||
...
|
||||
2020-05-28 0.236665
|
||||
2020-05-29 0.183886
|
||||
}
|
||||
...}
|
||||
"""
|
||||
pred_all_ics = {}
|
||||
for k, pred in pred_dict_all.items():
|
||||
pred_all_ics[k] = DelayedDict(["ic", "ric"], delayed(calc_ic)(pred, label, date_col=date_col, dropna=dropna))
|
||||
pred_all_ics = complex_parallel(Parallel(n_jobs=n_jobs, verbose=10), pred_all_ics)
|
||||
return pred_all_ics
|
||||
|
||||
@@ -5,12 +5,10 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.stats import spearmanr, pearsonr
|
||||
|
||||
|
||||
from ..data import D
|
||||
|
||||
from collections import OrderedDict
|
||||
@@ -243,4 +241,4 @@ def get_rank_ic(a, b):
|
||||
|
||||
|
||||
def get_normal_ic(a, b):
|
||||
return pearsonr(a, b).correlation
|
||||
return pearsonr(a, b)[0]
|
||||
|
||||
@@ -1,24 +1,23 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from copy import deepcopy
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from joblib import Parallel, delayed
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import task_train, TrainerR
|
||||
from qlib.data.dataset import DatasetH
|
||||
from tqdm.auto import tqdm
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from copy import deepcopy
|
||||
from joblib import Parallel, delayed # pylint: disable=E0401
|
||||
from typing import Dict, List, Union, Text, Tuple
|
||||
from qlib.data.dataset.utils import init_task_handler
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.model.meta.dataset import MetaTaskDataset
|
||||
from qlib.model.trainer import TrainerR
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import auto_filter_kwargs, get_date_by_shift, init_instance_by_config
|
||||
from qlib.utils.data import deepcopy_basic_type
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.utils import TimeAdjuster
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
class InternalData:
|
||||
|
||||
@@ -1,28 +1,26 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.model.meta.task import MetaTask
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
from tqdm.auto import tqdm
|
||||
import collections
|
||||
import copy
|
||||
from typing import Union, List, Tuple, Dict
|
||||
from typing import Union, List
|
||||
|
||||
from ....data.dataset.weight import Reweighter
|
||||
from ....model.meta.dataset import MetaTaskDataset
|
||||
from ....model.meta.model import MetaModel, MetaTaskModel
|
||||
from ....model.meta.model import MetaTaskModel
|
||||
from ....workflow import R
|
||||
|
||||
from .utils import ICLoss
|
||||
from .dataset import MetaDatasetDS
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.model.meta.task import MetaTask
|
||||
from qlib.contrib.meta.data_selection.net import PredNet
|
||||
|
||||
logger = get_module_logger("data selection")
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from qlib.contrib.torch import data_to_tensor
|
||||
|
||||
|
||||
class ICLoss(nn.Module):
|
||||
def forward(self, pred, y, idx, skip_size=50):
|
||||
"""forward.
|
||||
FIXME:
|
||||
- Sometimes the result differs slightly from that of `pandas.corr()`
|
||||
- It may be caused by the numerical precision of the model.
|
||||
|
||||
:param pred:
|
||||
:param y:
|
||||
|
||||
@@ -160,7 +160,7 @@ class DEnsembleModel(Model, FeatureInt):
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay**k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
|
||||
@@ -10,6 +10,7 @@ from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.interpret.base import LightGBMFInt
|
||||
from ...data.dataset.weight import Reweighter
|
||||
from qlib.workflow import R
|
||||
|
||||
|
||||
class LGBModel(ModelFT, LightGBMFInt):
|
||||
@@ -59,10 +60,12 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
num_boost_round=None,
|
||||
early_stopping_rounds=None,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
evals_result=None,
|
||||
reweighter=None,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
if evals_result is None:
|
||||
evals_result = {} # in case of unsafety of Python default values
|
||||
ds_l = self._prepare_data(dataset, reweighter)
|
||||
ds, names = list(zip(*ds_l))
|
||||
self.model = lgb.train(
|
||||
@@ -76,10 +79,13 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
),
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
)
|
||||
for k in names:
|
||||
evals_result[k] = list(evals_result[k].values())[0]
|
||||
for key, val in evals_result[k].items():
|
||||
name = f"{key}.{k}"
|
||||
for epoch, m in enumerate(val):
|
||||
R.log_metrics(**{name.replace("@", "_"): m}, step=epoch)
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if self.model is None:
|
||||
@@ -101,7 +107,7 @@ class LGBModel(ModelFT, LightGBMFInt):
|
||||
verbose level
|
||||
"""
|
||||
# Fine-tune the existing model by training for more rounds
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter)
|
||||
dtrain, _ = self._prepare_data(dataset, reweighter) # pylint: disable=W0632
|
||||
if dtrain.empty:
|
||||
raise ValueError("Empty data from dataset, please check your dataset config.")
|
||||
self.model = lgb.train(
|
||||
|
||||
@@ -58,7 +58,7 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
"""
|
||||
Test the signal in high frequency test set
|
||||
"""
|
||||
if self.model == None:
|
||||
if self.model is None:
|
||||
raise ValueError("Model hasn't been trained yet")
|
||||
df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
|
||||
df_test.dropna(inplace=True)
|
||||
@@ -111,7 +111,6 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
early_stopping_rounds=50,
|
||||
verbose_eval=20,
|
||||
evals_result=dict(),
|
||||
**kwargs
|
||||
):
|
||||
dtrain, dvalid = self._prepare_data(dataset)
|
||||
self.model = lgb.train(
|
||||
@@ -123,7 +122,6 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
early_stopping_rounds=early_stopping_rounds,
|
||||
verbose_eval=verbose_eval,
|
||||
evals_result=evals_result,
|
||||
**kwargs
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
import os
|
||||
from pdb import set_trace
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
|
||||
import copy
|
||||
from typing import Text, Union
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
@@ -182,11 +180,11 @@ class ADARNN(Model):
|
||||
continue
|
||||
|
||||
total_loss = torch.zeros(1).cuda()
|
||||
for i in range(len(index)):
|
||||
feature_s = list_feat[index[i][0]]
|
||||
feature_t = list_feat[index[i][1]]
|
||||
label_reg_s = list_label[index[i][0]]
|
||||
label_reg_t = list_label[index[i][1]]
|
||||
for i, n in enumerate(index):
|
||||
feature_s = list_feat[n[0]]
|
||||
feature_t = list_feat[n[1]]
|
||||
label_reg_s = list_label[n[0]]
|
||||
label_reg_t = list_label[n[1]]
|
||||
feature_all = torch.cat((feature_s, feature_t), 0)
|
||||
|
||||
if epoch < self.pre_epoch:
|
||||
@@ -410,7 +408,7 @@ class AdaRNN(nn.Module):
|
||||
in_size = hidden
|
||||
self.features = nn.Sequential(*features)
|
||||
|
||||
if use_bottleneck == True: # finance
|
||||
if use_bottleneck is True: # finance
|
||||
self.bottleneck = nn.Sequential(
|
||||
nn.Linear(n_hiddens[-1], bottleneck_width),
|
||||
nn.Linear(bottleneck_width, bottleneck_width),
|
||||
@@ -449,7 +447,7 @@ class AdaRNN(nn.Module):
|
||||
def forward_pre_train(self, x, len_win=0):
|
||||
out = self.gru_features(x)
|
||||
fea = out[0] # [2N,L,H]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -458,8 +456,8 @@ class AdaRNN(nn.Module):
|
||||
out_list_all, out_weight_list = out[1], out[2]
|
||||
out_list_s, out_list_t = self.get_features(out_list_all)
|
||||
loss_transfer = torch.zeros((1,)).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
h_start = 0
|
||||
for j in range(h_start, self.len_seq, 1):
|
||||
i_start = j - len_win if j - len_win >= 0 else 0
|
||||
@@ -471,7 +469,7 @@ class AdaRNN(nn.Module):
|
||||
else 1 / (self.len_seq - h_start) * (2 * len_win + 1)
|
||||
)
|
||||
loss_transfer = loss_transfer + weight * criterion_transder.compute(
|
||||
out_list_s[i][:, j, :], out_list_t[i][:, k, :]
|
||||
n[:, j, :], out_list_t[i][:, k, :]
|
||||
)
|
||||
return fc_out, loss_transfer, out_weight_list
|
||||
|
||||
@@ -484,7 +482,7 @@ class AdaRNN(nn.Module):
|
||||
out, _ = self.features[i](x_input.float())
|
||||
x_input = out
|
||||
out_lis.append(out)
|
||||
if self.model_type == "AdaRNN" and predict == False:
|
||||
if self.model_type == "AdaRNN" and predict is False:
|
||||
out_gate = self.process_gate_weight(x_input, i)
|
||||
out_weight_list.append(out_gate)
|
||||
return out, out_lis, out_weight_list
|
||||
@@ -524,10 +522,10 @@ class AdaRNN(nn.Module):
|
||||
else:
|
||||
weight = weight_mat
|
||||
dist_mat = torch.zeros(self.num_layers, self.len_seq).cuda()
|
||||
for i in range(len(out_list_s)):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=out_list_s[i].shape[2])
|
||||
for i, n in enumerate(out_list_s):
|
||||
criterion_transder = TransferLoss(loss_type=self.trans_loss, input_dim=n.shape[2])
|
||||
for j in range(self.len_seq):
|
||||
loss_trans = criterion_transder.compute(out_list_s[i][:, j, :], out_list_t[i][:, j, :])
|
||||
loss_trans = criterion_transder.compute(n[:, j, :], out_list_t[i][:, j, :])
|
||||
loss_transfer = loss_transfer + weight[i, j] * loss_trans
|
||||
dist_mat[i, j] = loss_trans
|
||||
return fc_out, loss_transfer, dist_mat, weight
|
||||
@@ -546,7 +544,7 @@ class AdaRNN(nn.Module):
|
||||
def predict(self, x):
|
||||
out = self.gru_features(x, predict=True)
|
||||
fea = out[0]
|
||||
if self.use_bottleneck == True:
|
||||
if self.use_bottleneck is True:
|
||||
fea_bottleneck = self.bottleneck(fea[:, -1, :])
|
||||
fc_out = self.fc(fea_bottleneck).squeeze()
|
||||
else:
|
||||
@@ -572,12 +570,12 @@ class TransferLoss:
|
||||
Returns:
|
||||
[tensor] -- transfer loss
|
||||
"""
|
||||
if self.loss_type == "mmd_lin" or self.loss_type == "mmd":
|
||||
if self.loss_type in ("mmd_lin", "mmd"):
|
||||
mmdloss = MMD_loss(kernel_type="linear")
|
||||
loss = mmdloss(X, Y)
|
||||
elif self.loss_type == "coral":
|
||||
loss = CORAL(X, Y)
|
||||
elif self.loss_type == "cosine" or self.loss_type == "cos":
|
||||
elif self.loss_type in ("cosine", "cos"):
|
||||
loss = 1 - cosine(X, Y)
|
||||
elif self.loss_type == "kl":
|
||||
loss = kl_div(X, Y)
|
||||
@@ -684,9 +682,9 @@ class MMD_loss(nn.Module):
|
||||
if fix_sigma:
|
||||
bandwidth = fix_sigma
|
||||
else:
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples ** 2 - n_samples)
|
||||
bandwidth = torch.sum(L2_distance.data) / (n_samples**2 - n_samples)
|
||||
bandwidth /= kernel_mul ** (kernel_num // 2)
|
||||
bandwidth_list = [bandwidth * (kernel_mul ** i) for i in range(kernel_num)]
|
||||
bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
|
||||
kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
|
||||
return sum(kernel_val)
|
||||
|
||||
|
||||
@@ -20,7 +20,6 @@ from qlib.contrib.model.pytorch_lstm import LSTMModel
|
||||
from qlib.contrib.model.pytorch_utils import count_parameters
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
from qlib.data.dataset.processor import CSRankNorm
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import get_or_create_path
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -150,7 +149,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -312,8 +311,8 @@ class ALSTMModel(nn.Module):
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
except Exception as e:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -20,7 +19,7 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.utils import ConcatDataset
|
||||
from ...data.dataset.weight import Reweighter
|
||||
@@ -160,7 +159,7 @@ class ALSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -320,8 +319,8 @@ class ALSTMModel(nn.Module):
|
||||
def _build_model(self):
|
||||
try:
|
||||
klass = getattr(nn, self.rnn_type.upper())
|
||||
except:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type)
|
||||
except Exception as e:
|
||||
raise ValueError("unknown rnn_type `%s`" % self.rnn_type) from e
|
||||
self.net = nn.Sequential()
|
||||
self.net.add_module("fc_in", nn.Linear(in_features=self.input_size, out_features=self.hid_size))
|
||||
self.net.add_module("act", nn.Tanh())
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -158,7 +157,7 @@ class GATs(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -263,7 +262,9 @@ class GATs(Model):
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -19,7 +18,6 @@ from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...contrib.model.pytorch_lstm import LSTMModel
|
||||
from ...contrib.model.pytorch_gru import GRUModel
|
||||
@@ -178,7 +176,7 @@ class GATs(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
@@ -279,7 +277,9 @@ class GATs(Model):
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path, map_location=self.device))
|
||||
|
||||
model_dict = self.GAT_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.GAT_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -150,7 +149,7 @@ class GRU(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -19,7 +18,6 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.utils import ConcatDataset
|
||||
from ...data.dataset.weight import Reweighter
|
||||
@@ -159,7 +157,7 @@ class GRU(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -17,11 +16,9 @@ from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
@@ -102,7 +99,7 @@ class LocalformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -18,9 +17,8 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from torch.nn.modules.container import ModuleList
|
||||
|
||||
@@ -101,7 +99,7 @@ class LocalformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -146,7 +145,7 @@ class LSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -18,7 +17,6 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...model.utils import ConcatDataset
|
||||
from ...data.dataset.weight import Reweighter
|
||||
@@ -140,7 +138,7 @@ class LSTM(Model):
|
||||
loss = weight * (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def loss_fn(self, pred, label):
|
||||
def loss_fn(self, pred, label, weight):
|
||||
mask = ~torch.isnan(label)
|
||||
|
||||
if weight is None:
|
||||
@@ -155,8 +153,8 @@ class LSTM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask], weight=None)
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -4,11 +4,13 @@
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from collections import defaultdict
|
||||
|
||||
import os
|
||||
import gc
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
from typing import Callable, Optional, Text, Union
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
|
||||
import torch
|
||||
@@ -20,9 +22,17 @@ from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...data.dataset.weight import Reweighter
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path
|
||||
from ...utils import (
|
||||
auto_filter_kwargs,
|
||||
init_instance_by_config,
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
)
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
from qlib.contrib.meta.data_selection.utils import ICLoss
|
||||
from torch.nn import DataParallel
|
||||
|
||||
|
||||
class DNNModelPytorch(Model):
|
||||
@@ -49,9 +59,6 @@ class DNNModelPytorch(Model):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim=360,
|
||||
output_dim=1,
|
||||
layers=(256,),
|
||||
lr=0.001,
|
||||
max_steps=300,
|
||||
batch_size=2000,
|
||||
@@ -64,14 +71,23 @@ class DNNModelPytorch(Model):
|
||||
GPU=0,
|
||||
seed=None,
|
||||
weight_decay=0.0,
|
||||
**kwargs
|
||||
data_parall=False,
|
||||
scheduler: Optional[Union[Callable]] = "default",  # when it is Callable, it accepts one argument named optimizer
|
||||
init_model=None,
|
||||
eval_train_metric=False,
|
||||
pt_model_uri="qlib.contrib.model.pytorch_nn.Net",
|
||||
pt_model_kwargs={
|
||||
"input_dim": 360,
|
||||
"layers": (256,),
|
||||
},
|
||||
valid_key=DataHandlerLP.DK_L,
|
||||
# TODO: Infer Key is a more reasonable key. But it requires more detailed processing on label processing
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("DNNModelPytorch")
|
||||
self.logger.info("DNN pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.layers = layers
|
||||
self.lr = lr
|
||||
self.max_steps = max_steps
|
||||
self.batch_size = batch_size
|
||||
@@ -81,41 +97,36 @@ class DNNModelPytorch(Model):
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
if isinstance(GPU, str):
|
||||
self.device = torch.device(GPU)
|
||||
else:
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
self.data_parall = data_parall
|
||||
self.eval_train_metric = eval_train_metric
|
||||
self.valid_key = valid_key
|
||||
|
||||
self.best_step = None
|
||||
|
||||
self.logger.info(
|
||||
"DNN parameters setting:"
|
||||
"\nlayers : {}"
|
||||
"\nlr : {}"
|
||||
"\nmax_steps : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop_rounds : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nlr_decay : {}"
|
||||
"\nlr_decay_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nseed : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
lr,
|
||||
max_steps,
|
||||
batch_size,
|
||||
early_stop_rounds,
|
||||
eval_steps,
|
||||
lr_decay,
|
||||
lr_decay_steps,
|
||||
optimizer,
|
||||
loss,
|
||||
seed,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
)
|
||||
f"\nlr : {lr}"
|
||||
f"\nmax_steps : {max_steps}"
|
||||
f"\nbatch_size : {batch_size}"
|
||||
f"\nearly_stop_rounds : {early_stop_rounds}"
|
||||
f"\neval_steps : {eval_steps}"
|
||||
f"\nlr_decay : {lr_decay}"
|
||||
f"\nlr_decay_steps : {lr_decay_steps}"
|
||||
f"\noptimizer : {optimizer}"
|
||||
f"\nloss_type : {loss}"
|
||||
f"\nseed : {seed}"
|
||||
f"\ndevice : {self.device}"
|
||||
f"\nuse_GPU : {self.use_gpu}"
|
||||
f"\nweight_decay : {weight_decay}"
|
||||
f"\nenable data parall : {self.data_parall}"
|
||||
f"\npt_model_uri: {pt_model_uri}"
|
||||
f"\npt_model_kwargs: {pt_model_kwargs}"
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
@@ -126,7 +137,14 @@ class DNNModelPytorch(Model):
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss))
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
if init_model is None:
|
||||
self.dnn_model = init_instance_by_config({"class": pt_model_uri, "kwargs": pt_model_kwargs})
|
||||
|
||||
if self.data_parall:
|
||||
self.dnn_model = DataParallel(self.dnn_model).to(self.device)
|
||||
else:
|
||||
self.dnn_model = init_model
|
||||
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
@@ -137,19 +155,24 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
if scheduler == "default":
|
||||
# Reduce learning rate when loss has stopped decrease
|
||||
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
||||
self.train_optimizer,
|
||||
mode="min",
|
||||
factor=0.5,
|
||||
patience=10,
|
||||
verbose=True,
|
||||
threshold=0.0001,
|
||||
threshold_mode="rel",
|
||||
cooldown=0,
|
||||
min_lr=0.00001,
|
||||
eps=1e-08,
|
||||
)
|
||||
elif scheduler is None:
|
||||
self.scheduler = None
|
||||
else:
|
||||
self.scheduler = scheduler(optimizer=self.train_optimizer)
|
||||
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
@@ -166,40 +189,48 @@ class DNNModelPytorch(Model):
|
||||
save_path=None,
|
||||
reweighter=None,
|
||||
):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
has_valid = "valid" in dataset.segments
|
||||
segments = ["train", "valid"]
|
||||
vars = ["x", "y", "w"]
|
||||
all_df = defaultdict(dict) # x_train, x_valid y_train, y_valid w_train, w_valid
|
||||
all_t = defaultdict(dict) # tensors
|
||||
for seg in segments:
|
||||
if seg in dataset.segments:
|
||||
# df_train df_valid
|
||||
df = dataset.prepare(
|
||||
seg, col_set=["feature", "label"], data_key=self.valid_key if seg == "valid" else DataHandlerLP.DK_L
|
||||
)
|
||||
all_df["x"][seg] = df["feature"]
|
||||
all_df["y"][seg] = df["label"].copy() # We have to use copy to remove the reference to release mem
|
||||
if reweighter is None:
|
||||
all_df["w"][seg] = pd.DataFrame(np.ones_like(all_df["y"][seg].values), index=df.index)
|
||||
elif isinstance(reweighter, Reweighter):
|
||||
all_df["w"][seg] = pd.DataFrame(reweighter.reweight(df))
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
|
||||
if reweighter is None:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
elif isinstance(reweighter, Reweighter):
|
||||
w_train = pd.DataFrame(reweighter.reweight(df_train))
|
||||
w_valid = pd.DataFrame(reweighter.reweight(df_valid))
|
||||
else:
|
||||
raise ValueError("Unsupported reweighter type.")
|
||||
# get tensors
|
||||
for v in vars:
|
||||
all_t[v][seg] = torch.from_numpy(all_df[v][seg].values).float()
|
||||
# if seg == "valid": # accelerate the eval of validation
|
||||
all_t[v][seg] = all_t[v][seg].to(self.device) # This will consume a lot of memory !!!!
|
||||
|
||||
evals_result[seg] = []
|
||||
# free memory
|
||||
del df
|
||||
del all_df["x"]
|
||||
gc.collect()
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
# train
|
||||
self.logger.info("training...")
|
||||
self.fitted = True
|
||||
# return
|
||||
# prepare training data
|
||||
x_train_values = torch.from_numpy(x_train.values).float()
|
||||
y_train_values = torch.from_numpy(y_train.values).float()
|
||||
w_train_values = torch.from_numpy(w_train.values).float()
|
||||
train_num = y_train_values.shape[0]
|
||||
# prepare validation data
|
||||
x_val_auto = torch.from_numpy(x_valid.values).float().to(self.device)
|
||||
y_val_auto = torch.from_numpy(y_valid.values).float().to(self.device)
|
||||
w_val_auto = torch.from_numpy(w_valid.values).float().to(self.device)
|
||||
train_num = all_t["y"]["train"].shape[0]
|
||||
|
||||
for step in range(1, self.max_steps + 1):
|
||||
if stop_steps >= self.early_stop_rounds:
|
||||
@@ -210,9 +241,9 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.train()
|
||||
self.train_optimizer.zero_grad()
|
||||
choice = np.random.choice(train_num, self.batch_size)
|
||||
x_batch_auto = x_train_values[choice].to(self.device)
|
||||
y_batch_auto = y_train_values[choice].to(self.device)
|
||||
w_batch_auto = w_train_values[choice].to(self.device)
|
||||
x_batch_auto = all_t["x"]["train"][choice].to(self.device)
|
||||
y_batch_auto = all_t["y"]["train"][choice].to(self.device)
|
||||
w_batch_auto = all_t["w"]["train"][choice].to(self.device)
|
||||
|
||||
# forward
|
||||
preds = self.dnn_model(x_batch_auto)
|
||||
@@ -226,44 +257,84 @@ class DNNModelPytorch(Model):
|
||||
train_loss += loss.val
|
||||
# for every `eval_steps` steps or at the last step, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
if has_valid:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
with torch.no_grad():
|
||||
self.dnn_model.eval()
|
||||
loss_val = AverageMeter()
|
||||
with torch.no_grad():
|
||||
self.dnn_model.eval()
|
||||
|
||||
# forward
|
||||
preds = self.dnn_model(x_val_auto)
|
||||
cur_loss_val = self.get_loss(preds, w_val_auto, y_val_auto, self.loss_type)
|
||||
loss_val.update(cur_loss_val.item())
|
||||
R.log_metrics(val_loss=loss_val.val, step=step)
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"[Epoch {}]: train_loss {:.6f}, valid_loss {:.6f}".format(step, train_loss, loss_val.val)
|
||||
)
|
||||
evals_result["train"].append(train_loss)
|
||||
evals_result["valid"].append(loss_val.val)
|
||||
if loss_val.val < best_loss:
|
||||
# forward
|
||||
preds = self._nn_predict(all_t["x"]["valid"], return_cpu=False)
|
||||
cur_loss_val = self.get_loss(preds, all_t["w"]["valid"], all_t["y"]["valid"], self.loss_type)
|
||||
loss_val = cur_loss_val.item()
|
||||
metric_val = (
|
||||
self.get_metric(
|
||||
preds.reshape(-1), all_t["y"]["valid"].reshape(-1), all_df["y"]["valid"].index
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
.item()
|
||||
)
|
||||
R.log_metrics(val_loss=loss_val, step=step)
|
||||
R.log_metrics(val_metric=metric_val, step=step)
|
||||
|
||||
if self.eval_train_metric:
|
||||
metric_train = (
|
||||
self.get_metric(
|
||||
self._nn_predict(all_t["x"]["train"], return_cpu=False),
|
||||
all_t["y"]["train"].reshape(-1),
|
||||
all_df["y"]["train"].index,
|
||||
)
|
||||
.detach()
|
||||
.cpu()
|
||||
.numpy()
|
||||
.item()
|
||||
)
|
||||
R.log_metrics(train_metric=metric_train, step=step)
|
||||
else:
|
||||
metric_train = np.nan
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
|
||||
best_loss, loss_val.val
|
||||
)
|
||||
f"[Step {step}]: train_loss {train_loss:.6f}, valid_loss {loss_val:.6f}, train_metric {metric_train:.6f}, valid_metric {metric_val:.6f}"
|
||||
)
|
||||
best_loss = loss_val.val
|
||||
stop_steps = 0
|
||||
torch.save(self.dnn_model.state_dict(), save_path)
|
||||
train_loss = 0
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
evals_result["train"].append(train_loss)
|
||||
evals_result["valid"].append(loss_val)
|
||||
if loss_val < best_loss:
|
||||
if verbose:
|
||||
self.logger.info(
|
||||
"\tvalid loss update from {:.6f} to {:.6f}, save checkpoint.".format(
|
||||
best_loss, loss_val
|
||||
)
|
||||
)
|
||||
best_loss = loss_val
|
||||
self.best_step = step
|
||||
R.log_metrics(best_step=self.best_step, step=step)
|
||||
stop_steps = 0
|
||||
torch.save(self.dnn_model.state_dict(), save_path)
|
||||
train_loss = 0
|
||||
# update learning rate
|
||||
if self.scheduler is not None:
|
||||
auto_filter_kwargs(self.scheduler.step, warning=False)(metrics=cur_loss_val, epoch=step)
|
||||
R.log_metrics(lr=self.get_lr(), step=step)
|
||||
else:
|
||||
# retraining mode
|
||||
if self.scheduler is not None:
|
||||
self.scheduler.step(epoch=step)
|
||||
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
|
||||
if has_valid:
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path, map_location=self.device))
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_lr(self):
|
||||
assert len(self.train_optimizer.param_groups) == 1
|
||||
return self.train_optimizer.param_groups[0]["lr"]
|
||||
|
||||
def get_loss(self, pred, w, target, loss_type):
|
||||
pred, w, target = pred.reshape(-1), w.reshape(-1), target.reshape(-1)
|
||||
if loss_type == "mse":
|
||||
sqr_loss = torch.mul(pred - target, pred - target)
|
||||
loss = torch.mul(sqr_loss, w).mean()
|
||||
@@ -274,15 +345,40 @@ class DNNModelPytorch(Model):
|
||||
else:
|
||||
raise NotImplementedError("loss {} is not supported!".format(loss_type))
|
||||
|
||||
def get_metric(self, pred, target, index):
|
||||
# NOTE: the order of the index must follow <datetime, instrument> sorted order
|
||||
return -ICLoss()(pred, target, index) # pylint: disable=E1130
|
||||
|
||||
def _nn_predict(self, data, return_cpu=True):
|
||||
"""Reusing predicting NN.
|
||||
Scenarios
|
||||
1) test inference (data may come from CPU and expect the output data is on CPU)
|
||||
2) evaluation on training (data may come from GPU)
|
||||
"""
|
||||
if not isinstance(data, torch.Tensor):
|
||||
if isinstance(data, pd.DataFrame):
|
||||
data = data.values
|
||||
data = torch.Tensor(data)
|
||||
data = data.to(self.device)
|
||||
preds = []
|
||||
self.dnn_model.eval()
|
||||
with torch.no_grad():
|
||||
batch_size = 8096
|
||||
for i in range(0, len(data), batch_size):
|
||||
x = data[i : i + batch_size]
|
||||
preds.append(self.dnn_model(x.to(self.device)).detach().reshape(-1))
|
||||
if return_cpu:
|
||||
preds = np.concatenate([pr.cpu().numpy() for pr in preds])
|
||||
else:
|
||||
preds = torch.cat(preds, axis=0)
|
||||
return preds
|
||||
|
||||
def predict(self, dataset: DatasetH, segment: Union[Text, slice] = "test"):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test_pd = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
x_test = torch.from_numpy(x_test_pd.values).float().to(self.device)
|
||||
self.dnn_model.eval()
|
||||
with torch.no_grad():
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
preds = self._nn_predict(x_test_pd)
|
||||
return pd.Series(preds.reshape(-1), index=x_test_pd.index)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
with save_multiple_parts_file(filename) as model_dir:
|
||||
@@ -322,15 +418,22 @@ class AverageMeter:
|
||||
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(self, input_dim, output_dim, layers=(256, 512, 768, 512, 256, 128, 64), loss="mse"):
|
||||
def __init__(self, input_dim, output_dim=1, layers=(256,), act="LeakyReLU"):
|
||||
super(Net, self).__init__()
|
||||
|
||||
layers = [input_dim] + list(layers)
|
||||
dnn_layers = []
|
||||
drop_input = nn.Dropout(0.05)
|
||||
dnn_layers.append(drop_input)
|
||||
hidden_units = input_dim
|
||||
for i, (_input_dim, hidden_units) in enumerate(zip(layers[:-1], layers[1:])):
|
||||
fc = nn.Linear(_input_dim, hidden_units)
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
|
||||
if act == "LeakyReLU":
|
||||
activation = nn.LeakyReLU(negative_slope=0.1, inplace=False)
|
||||
elif act == "SiLU":
|
||||
activation = nn.SiLU()
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
bn = nn.BatchNorm1d(hidden_units)
|
||||
seq = nn.Sequential(fc, bn, activation)
|
||||
dnn_layers.append(seq)
|
||||
@@ -338,7 +441,7 @@ class Net(nn.Module):
|
||||
dnn_layers.append(drop_input)
|
||||
fc = nn.Linear(hidden_units, output_dim)
|
||||
dnn_layers.append(fc)
|
||||
# optimizer
|
||||
# optimizer # pylint: disable=W0631
|
||||
self.dnn_layers = nn.ModuleList(dnn_layers)
|
||||
self._weight_init()
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -435,7 +434,7 @@ class SFM(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -378,7 +377,7 @@ class TabnetModel(Model):
|
||||
|
||||
def metric_fn(self, pred, label):
|
||||
mask = torch.isfinite(label)
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@ from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
@@ -158,7 +157,7 @@ class TCN(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -158,7 +158,7 @@ class TCN(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,20 +5,12 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
import random
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import get_or_create_path
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -263,7 +255,7 @@ class TCTS(Model):
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
|
||||
if save_path == None:
|
||||
if save_path is None:
|
||||
save_path = get_or_create_path(save_path)
|
||||
best_loss = np.inf
|
||||
while best_loss > self.lowest_valid_performance:
|
||||
|
||||
@@ -6,10 +6,8 @@ import os
|
||||
import copy
|
||||
import math
|
||||
import json
|
||||
import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import seaborn as sns
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torch
|
||||
@@ -24,7 +22,6 @@ except ImportError:
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from qlib.utils import get_or_create_path
|
||||
from qlib.constant import EPS
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
@@ -745,7 +742,7 @@ def evaluate(pred):
|
||||
score = pred.score
|
||||
label = pred.label
|
||||
diff = score - label
|
||||
MSE = (diff ** 2).mean()
|
||||
MSE = (diff**2).mean()
|
||||
MAE = (diff.abs()).mean()
|
||||
IC = score.corr(label, method="spearman")
|
||||
return {"MSE": MSE, "MAE": MAE, "IC": IC}
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Text, Union
|
||||
@@ -17,11 +16,9 @@ from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
# qrun examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml
|
||||
@@ -101,7 +98,7 @@ class TransformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -5,7 +5,6 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
@@ -18,9 +17,8 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
@@ -98,7 +96,7 @@ class TransformerModel(Model):
|
||||
|
||||
mask = torch.isfinite(label)
|
||||
|
||||
if self.metric == "" or self.metric == "loss":
|
||||
if self.metric in ("", "loss"):
|
||||
return -self.loss_fn(pred[mask], label[mask])
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
@@ -26,12 +26,12 @@ def count_parameters(models_or_parameters, unit="m"):
|
||||
else:
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
unit = unit.lower()
|
||||
if unit == "kb" or unit == "k":
|
||||
counts /= 2 ** 10
|
||||
elif unit == "mb" or unit == "m":
|
||||
counts /= 2 ** 20
|
||||
elif unit == "gb" or unit == "g":
|
||||
counts /= 2 ** 30
|
||||
if unit in ("kb", "k"):
|
||||
counts /= 2**10
|
||||
elif unit in ("mb", "m"):
|
||||
counts /= 2**20
|
||||
elif unit in ("gb", "g"):
|
||||
counts /= 2**30
|
||||
elif unit is not None:
|
||||
raise ValueError("Unknown unit: {:}".format(unit))
|
||||
return counts
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# MIT License
|
||||
# Copyright (c) 2018 CMU Locus Lab
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
@@ -56,7 +55,7 @@ class TemporalConvNet(nn.Module):
|
||||
layers = []
|
||||
num_levels = len(num_channels)
|
||||
for i in range(num_levels):
|
||||
dilation_size = 2 ** i
|
||||
dilation_size = 2**i
|
||||
in_channels = num_inputs if i == 0 else num_channels[i - 1]
|
||||
out_channels = num_channels[i]
|
||||
layers += [
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
# pylint: skip-file
|
||||
|
||||
'''
|
||||
TODO:
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import yaml
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import random
|
||||
import pandas as pd
|
||||
from ...data import D
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import fire
|
||||
import pandas as pd
|
||||
import pathlib
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import logging
|
||||
|
||||
from ...log import get_module_logger
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import pathlib
|
||||
import pickle
|
||||
import yaml
|
||||
|
||||
@@ -1,12 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
from qlib.data.cache import H
|
||||
from qlib.data.data import Cal
|
||||
from qlib.data.ops import ElemOperator
|
||||
|
||||
@@ -34,7 +34,7 @@ def _group_return(pred_label: pd.DataFrame = None, reverse: bool = False, N: int
|
||||
{
|
||||
"Group%d"
|
||||
% (i + 1): pred_label_drop.groupby(level="datetime")["label"].apply(
|
||||
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean()
|
||||
lambda x: x[len(x) // N * i : len(x) // N * (i + 1)].mean() # pylint: disable=W0640
|
||||
)
|
||||
for i in range(N)
|
||||
}
|
||||
|
||||
7
qlib/contrib/report/data/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This module is designed to analyze data
|
||||
|
||||
"""
|
||||
202
qlib/contrib/report/data/ana.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from qlib.contrib.report.data.base import FeaAnalyser
|
||||
from qlib.contrib.report.utils import sub_fig_generator
|
||||
from qlib.utils.paral import datetime_groupby_apply
|
||||
from qlib.contrib.eva.alpha import pred_autocorr_all
|
||||
from loguru import logger
|
||||
import seaborn as sns
|
||||
|
||||
DT_COL_NAME = "datetime"
|
||||
|
||||
|
||||
class CombFeaAna(FeaAnalyser):
|
||||
"""
|
||||
Combine the sub feature analysers and plot them in a single graph
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: pd.DataFrame, *fea_ana_cls):
|
||||
if len(fea_ana_cls) <= 1:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
self._fea_ana_l = [fcls(dataset) for fcls in fea_ana_cls]
|
||||
super().__init__(dataset=dataset)
|
||||
|
||||
def skip(self, col):
|
||||
return np.all(list(map(lambda fa: fa.skip(col), self._fea_ana_l)))
|
||||
|
||||
def calc_stat_values(self):
|
||||
"""The statistics of features are finished in the underlying analysers"""
|
||||
|
||||
def plot_all(self, *args, **kwargs):
|
||||
|
||||
ax_gen = iter(sub_fig_generator(row_n=len(self._fea_ana_l), *args, **kwargs))
|
||||
|
||||
for col in self._dataset:
|
||||
if not self.skip(col):
|
||||
axes = next(ax_gen)
|
||||
for fa, ax in zip(self._fea_ana_l, axes):
|
||||
if not fa.skip(col):
|
||||
fa.plot_single(col, ax)
|
||||
ax.set_xlabel("")
|
||||
ax.set_title("")
|
||||
axes[0].set_title(col)
|
||||
|
||||
|
||||
class NumFeaAnalyser(FeaAnalyser):
|
||||
def skip(self, col):
|
||||
is_obj = np.issubdtype(self._dataset[col], np.dtype("O"))
|
||||
if is_obj:
|
||||
logger.info(f"{col} is not numeric and is skipped")
|
||||
return is_obj
|
||||
|
||||
|
||||
class ValueCNT(FeaAnalyser):
|
||||
def __init__(self, dataset: pd.DataFrame, ratio=False):
|
||||
self.ratio = ratio
|
||||
super().__init__(dataset)
|
||||
|
||||
def calc_stat_values(self):
|
||||
self._val_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._val_cnt[col] = item.groupby(DT_COL_NAME).apply(lambda s: len(s.unique()))
|
||||
self._val_cnt = pd.DataFrame(self._val_cnt)
|
||||
if self.ratio:
|
||||
self._val_cnt = self._val_cnt.div(self._dataset.groupby(DT_COL_NAME).size(), axis=0)
|
||||
|
||||
# TODO: transfer this feature to other analysers
|
||||
ymin, ymax = self._val_cnt.min().min(), self._val_cnt.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._val_cnt[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaDistAna(NumFeaAnalyser):
|
||||
def plot_single(self, col, ax):
|
||||
sns.histplot(self._dataset[col], ax=ax, kde=False, bins=100)
|
||||
ax.set_xlabel("")
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class FeaInfAna(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._inf_cnt = {}
|
||||
for col, item in self._dataset.items():
|
||||
if not super().skip(col):
|
||||
self._inf_cnt[col] = item.apply(np.isinf).astype(np.int).groupby(DT_COL_NAME).sum()
|
||||
self._inf_cnt = pd.DataFrame(self._inf_cnt)
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._inf_cnt) or (self._inf_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._inf_cnt[col].plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaNanAna(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._nan_cnt[col].plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaNanAnaRatio(FeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._nan_cnt = self._dataset.isna().groupby(DT_COL_NAME).sum()
|
||||
self._total_cnt = self._dataset.groupby(DT_COL_NAME).size()
|
||||
|
||||
def skip(self, col):
|
||||
return (col not in self._nan_cnt) or (self._nan_cnt[col].sum() == 0)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
(self._nan_cnt[col] / self._total_cnt).plot(ax=ax, title=col)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaACAna(FeaAnalyser):
|
||||
"""Analysis the auto-correlation of features"""
|
||||
|
||||
def calc_stat_values(self):
|
||||
self._fea_corr = pred_autocorr_all(self._dataset.to_dict("series"))
|
||||
df = pd.DataFrame(self._fea_corr)
|
||||
ymin, ymax = df.min().min(), df.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._fea_corr[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
|
||||
|
||||
class FeaSkewTurt(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._skew = datetime_groupby_apply(self._dataset, "skew", skip_group=True)
|
||||
self._kurt = datetime_groupby_apply(self._dataset, pd.DataFrame.kurt, skip_group=True)
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._skew[col].plot(ax=ax, label="skew")
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("skew")
|
||||
ax.legend()
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._kurt[col].plot(ax=right_ax, label="kurt", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("kurt")
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
ax.legend().set_visible(False)
|
||||
right_ax.legend(h1 + h2, l1 + l2)
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class FeaMeanStd(NumFeaAnalyser):
|
||||
def calc_stat_values(self):
|
||||
self._std = self._dataset.groupby(DT_COL_NAME).std()
|
||||
self._mean = self._dataset.groupby(DT_COL_NAME).mean()
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._mean[col].plot(ax=ax, label="mean")
|
||||
ax.set_xlabel("")
|
||||
ax.set_ylabel("mean")
|
||||
ax.legend()
|
||||
|
||||
right_ax = ax.twinx()
|
||||
|
||||
self._std[col].plot(ax=right_ax, label="std", color="green")
|
||||
right_ax.set_xlabel("")
|
||||
right_ax.set_ylabel("std")
|
||||
|
||||
h1, l1 = ax.get_legend_handles_labels()
|
||||
h2, l2 = right_ax.get_legend_handles_labels()
|
||||
|
||||
ax.legend().set_visible(False)
|
||||
right_ax.legend(h1 + h2, l1 + l2)
|
||||
ax.set_title(col)
|
||||
|
||||
|
||||
class RawFeaAna(FeaAnalyser):
|
||||
"""
|
||||
Motivation:
|
||||
- display the values without further analysis
|
||||
"""
|
||||
|
||||
def calc_stat_values(self):
|
||||
ymin, ymax = self._dataset.min().min(), self._dataset.max().max()
|
||||
self.ylim = (ymin - 0.05 * (ymax - ymin), ymax + 0.05 * (ymax - ymin))
|
||||
|
||||
def plot_single(self, col, ax):
|
||||
self._dataset[col].plot(ax=ax, title=col, ylim=self.ylim)
|
||||
ax.set_xlabel("")
|
||||
36
qlib/contrib/report/data/base.py
Normal file
@@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
This module is responsible for analysing data

Assumptions
- The analysers analyse each feature individually

"""
import pandas as pd
from qlib.log import TimeInspector
from qlib.contrib.report.utils import sub_fig_generator


class FeaAnalyser:
    def __init__(self, dataset: pd.DataFrame):
        self._dataset = dataset
        with TimeInspector.logt("calc_stat_values"):
            self.calc_stat_values()

    def calc_stat_values(self):
        pass

    def plot_single(self, col, ax):
        raise NotImplementedError(f"This type of input is not supported")

    def skip(self, col):
        return False

    def plot_all(self, *args, **kwargs):

        ax_gen = iter(sub_fig_generator(*args, **kwargs))
        for col in self._dataset:
            if not self.skip(col):
                ax = next(ax_gen)
                self.plot_single(col, ax)
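The base class above follows a template-method layout: `__init__` stores the dataset and calls `calc_stat_values` once, and `plot_all` walks the columns, asking `skip` before handing each one to `plot_single`. The sketch below mirrors that control flow with the standard library only, so it runs without pandas or matplotlib. The names `Analyser` and `MissingCount`, the dict-of-lists dataset and the plain lists standing in for axes are all assumptions made for illustration, not part of qlib.

```python
class Analyser:
    """Stand-in for the base class: statistics are computed once, in __init__."""

    def __init__(self, dataset):
        self._dataset = dataset
        self.calc_stat_values()

    def calc_stat_values(self):
        pass

    def skip(self, col):
        return False

    def plot_single(self, col, ax):
        raise NotImplementedError("subclasses decide how to draw one column")

    def plot_all(self, axes):
        ax_gen = iter(axes)
        for col in self._dataset:
            # An axis is consumed only for columns that are actually drawn.
            if not self.skip(col):
                self.plot_single(col, next(ax_gen))


class MissingCount(Analyser):
    """Counts None entries per column and skips columns that have none."""

    def calc_stat_values(self):
        self._missing = {col: sum(v is None for v in vals) for col, vals in self._dataset.items()}

    def skip(self, col):
        return self._missing[col] == 0

    def plot_single(self, col, ax):
        # A real analyser would draw on a matplotlib axis; here `ax` is a plain list.
        ax.append((col, self._missing[col]))


dataset = {"open": [1.0, None, 3.0], "close": [1.0, 2.0, 3.0], "volume": [None, None, 5]}
axes = [[], []]
MissingCount(dataset).plot_all(axes)
```

Because `skip` is consulted before an axis is drawn from the generator, skipped columns leave no empty panel behind.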
@@ -282,8 +282,10 @@ class SubplotsGraph:
|
||||
if self._subplots_kwargs is None:
|
||||
self._init_subplots_kwargs()
|
||||
|
||||
self.__cols = self._subplots_kwargs.get("cols", 2)
|
||||
self.__rows = self._subplots_kwargs.get("rows", math.ceil(len(self._df.columns) / self.__cols))
|
||||
self.__cols = self._subplots_kwargs.get("cols", 2) # pylint: disable=W0238
|
||||
self.__rows = self._subplots_kwargs.get( # pylint: disable=W0238
|
||||
"rows", math.ceil(len(self._df.columns) / self.__cols)
|
||||
)
|
||||
|
||||
self._sub_graph_data = sub_graph_data
|
||||
if self._sub_graph_data is None:
|
||||
|
||||
45
qlib/contrib/report/utils.py
Normal file
@@ -0,0 +1,45 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import matplotlib.pyplot as plt


def sub_fig_generator(sub_fs=(3, 3), col_n=10, row_n=1, wspace=None, hspace=None, sharex=False, sharey=False):
    """sub_fig_generator.
    It returns a generator; each figure row contains <col_n> subgraphs

    FIXME: Known limitation:
    - The last row will not be plotted automatically, please plot it outside the function

    Parameters
    ----------
    sub_fs :
        the figure size of each subgraph in <col_n> * <row_n> subgraphs
    col_n :
        the number of subgraphs in each row; a new figure is created after <col_n> subgraphs have been yielded.
    row_n :
        the number of subgraphs in each column
    wspace :
        the width of the blank space between subgraphs in each row
    hspace :
        the height of the blank space between subgraphs in each column
        You can try 0.3 if you feel it is too crowded

    Returns
    -------
    Each iteration yields the axes of one column: an array of <row_n> axes, or a single axis when <row_n> is 1.
    """
    assert col_n > 1

    while True:
        fig, axes = plt.subplots(
            row_n, col_n, figsize=(sub_fs[0] * col_n, sub_fs[1] * row_n), sharex=sharex, sharey=sharey
        )
        plt.subplots_adjust(wspace=wspace, hspace=hspace)
        axes = axes.reshape(row_n, col_n)

        for col in range(col_n):
            res = axes[:, col].squeeze()
            if res.size == 1:
                res = res.item()
            yield res
        plt.show()
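The known limitation mentioned in the docstring follows from how generators suspend: `plt.show()` sits after the inner loop, so it only runs when the consumer asks for one more axis than the current figure holds. The sketch below reproduces that control flow with the standard library only. `cell_generator`, `new_figure` and the `show` callback are hypothetical stand-ins for `sub_fig_generator`, `plt.subplots` and `plt.show`.

```python
def cell_generator(new_figure, show, col_n=3):
    """Yield cells one by one; allocate a new figure every `col_n` cells."""
    assert col_n > 1
    while True:
        cells = new_figure(col_n)  # stand-in for plt.subplots(...)
        for cell in cells:
            yield cell
        # Runs only when the consumer requests one more cell than the figure holds.
        show(cells)  # stand-in for plt.show()


created, shown = [], []


def new_figure(col_n):
    cells = [f"fig{len(created)}-cell{i}" for i in range(col_n)]
    created.append(cells)
    return cells


gen = cell_generator(new_figure, shown.append, col_n=3)
taken = [next(gen) for _ in range(4)]
# At this point the first figure has been shown; the second, partly used one has not.
```

A caller that stops after the fourth cell therefore has to display the last figure itself, which is what the FIXME asks for.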
@@ -10,4 +10,3 @@ class BaseOptimizer(abc.ABC):
|
||||
@abc.abstractmethod
def __call__(self, *args, **kwargs) -> object:
"""Generate an optimized portfolio allocation"""
|
||||
pass
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
|
||||
import numpy as np
|
||||
import cvxpy as cp
|
||||
import pandas as pd
|
||||
|
||||
from typing import Union, Optional, Dict, Any, List
|
||||
|
||||
@@ -126,7 +125,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer):
|
||||
|
||||
# objective
|
||||
ret = d @ r # excess return
|
||||
risk = cp.quad_form(v, cov_b) + var_u @ (d ** 2) # tracking error
|
||||
risk = cp.quad_form(v, cov_b) + var_u @ (d**2) # tracking error
|
||||
obj = cp.Maximize(ret - self.lamb * risk)
|
||||
|
||||
# weight bounds
|
||||
@@ -156,7 +155,7 @@ class EnhancedIndexingOptimizer(BaseOptimizer):
|
||||
|
||||
# factor deviation
|
||||
if self.f_dev is not None:
|
||||
cons.extend([v >= -self.f_dev, v <= self.f_dev])
|
||||
cons.extend([v >= -self.f_dev, v <= self.f_dev]) # pylint: disable=E1130
|
||||
|
||||
# total turnover constraint
|
||||
t_cons = []
|
||||
|
||||
@@ -6,7 +6,6 @@ This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.decision import BaseTradeDecision, TradeDecisionWO
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import os
|
||||
import copy
|
||||
import warnings
|
||||
import cvxpy as cp
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -15,11 +14,10 @@ from qlib.model.base import BaseModel
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.position import Position
|
||||
from qlib.backtest.signal import Signal, create_signal_from
|
||||
from qlib.backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
|
||||
from qlib.backtest.decision import Order, OrderDir, TradeDecisionWO
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_pre_trading_date, load_dataset
|
||||
from qlib.utils.resam import resam_ts_data
|
||||
from qlib.contrib.strategy.order_generator import OrderGenWInteract, OrderGenWOInteract
|
||||
from qlib.contrib.strategy.order_generator import OrderGenWOInteract
|
||||
from qlib.contrib.strategy.optimizer import EnhancedIndexingOptimizer
|
||||
|
||||
|
||||
@@ -484,7 +482,7 @@ class EnhancedIndexingStrategy(WeightStrategyBase):
|
||||
r=score,
|
||||
F=factor_exp,
|
||||
cov_b=factor_cov,
|
||||
var_u=specific_risk ** 2,
|
||||
var_u=specific_risk**2,
|
||||
w0=cur_weight,
|
||||
wb=bench_weight,
|
||||
mfh=mask_force_hold,
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
# pylint: skip-file
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import yaml
|
||||
import copy
|
||||
import os
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
# coding=utf-8
|
||||
|
||||
import argparse
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
from hyperopt import hp
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# pylint: skip-file
|
||||
|
||||
import os
|
||||
import yaml
|
||||
import json
|
||||
|
||||
@@ -6,7 +6,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import abc
|
||||
import pandas as pd
|
||||
|
||||
from ..log import get_module_logger
|
||||
|
||||
@@ -21,107 +20,107 @@ class Expression(abc.ABC):
|
||||
return str(self)
|
||||
|
||||
def __gt__(self, other):
|
||||
from .ops import Gt
|
||||
from .ops import Gt # pylint: disable=C0415
|
||||
|
||||
return Gt(self, other)
|
||||
|
||||
def __ge__(self, other):
|
||||
from .ops import Ge
|
||||
from .ops import Ge # pylint: disable=C0415
|
||||
|
||||
return Ge(self, other)
|
||||
|
||||
def __lt__(self, other):
|
||||
from .ops import Lt
|
||||
from .ops import Lt # pylint: disable=C0415
|
||||
|
||||
return Lt(self, other)
|
||||
|
||||
def __le__(self, other):
|
||||
from .ops import Le
|
||||
from .ops import Le # pylint: disable=C0415
|
||||
|
||||
return Le(self, other)
|
||||
|
||||
def __eq__(self, other):
|
||||
from .ops import Eq
|
||||
from .ops import Eq # pylint: disable=C0415
|
||||
|
||||
return Eq(self, other)
|
||||
|
||||
def __ne__(self, other):
|
||||
from .ops import Ne
|
||||
from .ops import Ne # pylint: disable=C0415
|
||||
|
||||
return Ne(self, other)
|
||||
|
||||
def __add__(self, other):
|
||||
from .ops import Add
|
||||
from .ops import Add # pylint: disable=C0415
|
||||
|
||||
return Add(self, other)
|
||||
|
||||
def __radd__(self, other):
|
||||
from .ops import Add
|
||||
from .ops import Add # pylint: disable=C0415
|
||||
|
||||
return Add(other, self)
|
||||
|
||||
def __sub__(self, other):
|
||||
from .ops import Sub
|
||||
from .ops import Sub # pylint: disable=C0415
|
||||
|
||||
return Sub(self, other)
|
||||
|
||||
def __rsub__(self, other):
|
||||
from .ops import Sub
|
||||
from .ops import Sub # pylint: disable=C0415
|
||||
|
||||
return Sub(other, self)
|
||||
|
||||
def __mul__(self, other):
|
||||
from .ops import Mul
|
||||
from .ops import Mul # pylint: disable=C0415
|
||||
|
||||
return Mul(self, other)
|
||||
|
||||
def __rmul__(self, other):
|
||||
from .ops import Mul
|
||||
from .ops import Mul # pylint: disable=C0415
|
||||
|
||||
return Mul(self, other)
|
||||
|
||||
def __div__(self, other):
|
||||
from .ops import Div
|
||||
from .ops import Div # pylint: disable=C0415
|
||||
|
||||
return Div(self, other)
|
||||
|
||||
def __rdiv__(self, other):
|
||||
from .ops import Div
|
||||
from .ops import Div # pylint: disable=C0415
|
||||
|
||||
return Div(other, self)
|
||||
|
||||
def __truediv__(self, other):
|
||||
from .ops import Div
|
||||
from .ops import Div # pylint: disable=C0415
|
||||
|
||||
return Div(self, other)
|
||||
|
||||
def __rtruediv__(self, other):
|
||||
from .ops import Div
|
||||
from .ops import Div # pylint: disable=C0415
|
||||
|
||||
return Div(other, self)
|
||||
|
||||
def __pow__(self, other):
|
||||
from .ops import Power
|
||||
from .ops import Power # pylint: disable=C0415
|
||||
|
||||
return Power(self, other)
|
||||
|
||||
def __and__(self, other):
|
||||
from .ops import And
|
||||
from .ops import And # pylint: disable=C0415
|
||||
|
||||
return And(self, other)
|
||||
|
||||
def __rand__(self, other):
|
||||
from .ops import And
|
||||
from .ops import And # pylint: disable=C0415
|
||||
|
||||
return And(other, self)
|
||||
|
||||
def __or__(self, other):
|
||||
from .ops import Or
|
||||
from .ops import Or # pylint: disable=C0415
|
||||
|
||||
return Or(self, other)
|
||||
|
||||
def __ror__(self, other):
|
||||
from .ops import Or
|
||||
from .ops import Or # pylint: disable=C0415
|
||||
|
||||
return Or(other, self)
|
||||
|
||||
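The operator methods in this hunk do not compute anything; each one wraps its operands in an operator node, and the reflected variants (`__radd__`, `__rsub__`, ...) swap the operand order so that `1 + expr` also works. In qlib the node classes live in `.ops`, which itself imports the base module, hence the imports inside the method bodies and the `C0415` suppressions. The single-file sketch below needs no such import because names are resolved at call time. The `BinaryOp` helper and the `$name` string form are assumptions for illustration only.

```python
class Expression:
    """Base node; arithmetic operators build a tree instead of computing a value."""

    def __add__(self, other):
        return Add(self, other)

    def __radd__(self, other):
        # Called for `1 + expr`, when the left operand does not know about Expression.
        return Add(other, self)

    def __sub__(self, other):
        return Sub(self, other)

    def __rsub__(self, other):
        # Operand order matters for subtraction, so the reflected form swaps it.
        return Sub(other, self)


class Feature(Expression):
    def __init__(self, name):
        self.name = name

    def __str__(self):
        return "$" + self.name


class BinaryOp(Expression):
    def __init__(self, left, right):
        self.left, self.right = left, right

    def __str__(self):
        return f"{type(self).__name__}({self.left},{self.right})"


class Add(BinaryOp):
    pass


class Sub(BinaryOp):
    pass


expr = 1 + Feature("close") - Feature("open")
text = str(expr)
```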
@@ -144,7 +143,7 @@ class Expression(abc.ABC):
|
||||
pd.Series
|
||||
feature series: The index of the series is the calendar index
|
||||
"""
|
||||
from .cache import H
|
||||
from .cache import H # pylint: disable=C0415
|
||||
|
||||
# cache
|
||||
args = str(self), instrument, start_index, end_index, freq
|
||||
@@ -215,7 +214,7 @@ class Feature(Expression):
|
||||
|
||||
def _load_internal(self, instrument, start_index, end_index, freq):
|
||||
# load
|
||||
from .data import FeatureD
|
||||
from .data import FeatureD # pylint: disable=C0415
|
||||
|
||||
return FeatureD.feature(instrument, str(self), start_index, end_index, freq)
|
||||
|
||||
@@ -232,5 +231,3 @@ class ExpressionOps(Expression):
|
||||
This kind of feature will use operator for feature
|
||||
construction on the fly.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
@@ -33,8 +33,7 @@ from ..utils import (
|
||||
|
||||
from ..log import get_module_logger
|
||||
from .base import Feature
|
||||
|
||||
from .ops import Operators
|
||||
from .ops import Operators # pylint: disable=W0611
|
||||
|
||||
|
||||
class QlibCacheException(RuntimeError):
|
||||
@@ -229,8 +228,8 @@ class CacheUtils:
|
||||
try:
|
||||
d["meta"]["last_visit"] = str(time.time())
|
||||
d["meta"]["visits"] = d["meta"]["visits"] + 1
|
||||
except KeyError:
|
||||
raise KeyError("Unknown meta keyword")
|
||||
except KeyError as key_e:
|
||||
raise KeyError("Unknown meta keyword") from key_e
|
||||
pickle.dump(d, f, protocol=C.dump_protocol_version)
|
||||
except Exception as e:
|
||||
get_module_logger("CacheUtils").warning(f"visit {cache_path} cache error: {e}")
|
||||
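The change from a bare `raise KeyError(...)` to `raise KeyError(...) from key_e` keeps the original exception attached to the new one as `__cause__`, so the traceback shows both. A minimal stdlib sketch of the same pattern follows; the function `bump_visits` and its flat `meta` dict are hypothetical and only loosely modelled on the hunk above.

```python
def bump_visits(meta):
    """Increment a visit counter, re-raising a missing key with context attached."""
    try:
        meta["visits"] = meta["visits"] + 1
    except KeyError as key_e:
        # `from key_e` stores the original exception on __cause__.
        raise KeyError("Unknown meta keyword") from key_e
    return meta


try:
    bump_visits({})
except KeyError as err:
    cause = err.__cause__
```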
@@ -239,7 +238,7 @@ class CacheUtils:
|
||||
def acquire(lock, lock_name):
|
||||
try:
|
||||
lock.acquire()
|
||||
except redis_lock.AlreadyAcquired:
|
||||
except redis_lock.AlreadyAcquired as lock_acquired:
|
||||
raise QlibCacheException(
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
You can use the following command to clear your redis keys and rerun your commands:
|
||||
@@ -249,7 +248,7 @@ class CacheUtils:
|
||||
> quit
|
||||
If the issue is not resolved, use "keys *" to find if multiple keys exist. If so, try using "flushall" to clear all the keys.
|
||||
"""
|
||||
)
|
||||
) from lock_acquired
|
||||
|
||||
@staticmethod
|
||||
@contextlib.contextmanager
|
||||
@@ -507,7 +506,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
_instrument_dir = self.get_cache_dir(freq).joinpath(instrument.lower())
|
||||
cache_path = _instrument_dir.joinpath(_cache_uri)
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
from .data import Cal # pylint: disable=C0415
|
||||
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
|
||||
@@ -599,7 +598,7 @@ class DiskExpressionCache(ExpressionCache):
|
||||
last_update_time = d["info"]["last_update"]
|
||||
|
||||
# get newest calendar
|
||||
from .data import Cal, ExpressionD
|
||||
from .data import Cal, ExpressionD # pylint: disable=C0415
|
||||
|
||||
whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
|
||||
# calendar since last updated.
|
||||
@@ -753,7 +752,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
if disk_cache == 0:
|
||||
# In this case, server only checks the expression cache.
|
||||
# The client will load the cache data by itself.
|
||||
from .data import LocalDatasetProvider
|
||||
from .data import LocalDatasetProvider # pylint: disable=C0415
|
||||
|
||||
LocalDatasetProvider.multi_cache_walker(instruments, fields, start_time, end_time, freq)
|
||||
return ""
|
||||
@@ -895,7 +894,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
:return type pd.DataFrame; The fields of the returned DataFrame are consistent with the parameters of the function.
|
||||
"""
|
||||
# get calendar
|
||||
from .data import Cal
|
||||
from .data import Cal # pylint: disable=C0415
|
||||
|
||||
cache_path = Path(cache_path)
|
||||
_calendar = Cal.calendar(freq=freq)
|
||||
@@ -970,14 +969,14 @@ class DiskDatasetCache(DatasetCache):
|
||||
index_data = im.get_index()
|
||||
|
||||
self.logger.debug("Updating dataset: {}".format(d))
|
||||
from .data import Inst
|
||||
from .data import Inst # pylint: disable=C0415
|
||||
|
||||
if Inst.get_inst_type(instruments) == Inst.DICT:
|
||||
self.logger.info(f"The file {cache_uri} has dict cache. Skip updating")
|
||||
return 1
|
||||
|
||||
# get newest calendar
|
||||
from .data import Cal
|
||||
from .data import Cal # pylint: disable=C0415
|
||||
|
||||
whole_calendar = Cal.calendar(start_time=None, end_time=None, freq=freq)
|
||||
# The calendar since last updated
|
||||
@@ -994,7 +993,7 @@ class DiskDatasetCache(DatasetCache):
|
||||
current_index = len(whole_calendar) - len(new_calendar) + 1
|
||||
|
||||
# To avoid recursive import
|
||||
from .data import ExpressionD
|
||||
from .data import ExpressionD # pylint: disable=C0415
|
||||
|
||||
# The existing data length
|
||||
lft_etd = rght_etd = 0
|
||||
|
||||
@@ -5,17 +5,13 @@
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import re
|
||||
import abc
|
||||
import time
|
||||
import copy
|
||||
import queue
|
||||
import bisect
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
from typing import Iterable, Union
|
||||
from typing import List, Union
|
||||
|
||||
# For supporting multiprocessing in outer code, joblib is used
|
||||
@@ -23,13 +19,10 @@ from joblib import delayed
|
||||
|
||||
from .cache import H
|
||||
from ..config import C
|
||||
from .base import Feature
|
||||
from .ops import Operators
|
||||
from .inst_processor import InstProcessor
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import Freq
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from .cache import DiskDatasetCache
|
||||
from ..utils import (
|
||||
Wrapper,
|
||||
init_instance_by_config,
|
||||
@@ -43,6 +36,7 @@ from ..utils import (
|
||||
time_to_slc_point,
|
||||
)
|
||||
from ..utils.paral import ParallelExt
|
||||
from .ops import Operators # pylint: disable=W0611
|
||||
|
||||
|
||||
class ProviderBackendMixin:
|
||||
@@ -144,10 +138,10 @@ class CalendarProvider(abc.ABC):
|
||||
if start_time not in calendar_index:
|
||||
try:
|
||||
start_time = calendar[bisect.bisect_left(calendar, start_time)]
|
||||
except IndexError:
|
||||
except IndexError as index_e:
|
||||
raise IndexError(
|
||||
"`start_time` uses a future date, if you want to get future trading days, you can use: `future=True`"
|
||||
)
|
||||
) from index_e
|
||||
start_index = calendar_index[start_time]
|
||||
if end_time not in calendar_index:
|
||||
end_time = calendar[bisect.bisect_right(calendar, end_time) - 1]
|
||||
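The lookup in this hunk snaps a `start_time` that is not on the calendar forward with `bisect_left`, and an `end_time` backward with `bisect_right(...) - 1`. The sketch below shows that alignment on a sorted list of ISO date strings standing in for the calendar. The function name `locate`, the sample dates and the shortened error message are assumptions, not qlib's API.

```python
import bisect


def locate(calendar, start_time, end_time):
    """Snap [start_time, end_time] onto a sorted calendar and return the aligned bounds."""
    if start_time not in calendar:
        try:
            # First calendar entry that is >= start_time.
            start_time = calendar[bisect.bisect_left(calendar, start_time)]
        except IndexError as index_e:
            # Chain the original error so the traceback keeps the root cause.
            raise IndexError("`start_time` is later than the last calendar entry") from index_e
    if end_time not in calendar:
        # Last calendar entry that is <= end_time.
        end_time = calendar[bisect.bisect_right(calendar, end_time) - 1]
    return start_time, end_time


calendar = ["2020-01-02", "2020-01-03", "2020-01-06", "2020-01-07"]
aligned = locate(calendar, "2020-01-04", "2020-01-08")
```

Here the weekend start date moves forward to the next listed day and the end date moves back to the last listed day.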
@@ -246,7 +240,7 @@ class InstrumentProvider(abc.ABC):
|
||||
"""
|
||||
if isinstance(market, list):
|
||||
return market
|
||||
from .filter import SeriesDFilter
|
||||
from .filter import SeriesDFilter # pylint: disable=C0415
|
||||
|
||||
if filter_pipe is None:
|
||||
filter_pipe = []
|
||||
@@ -672,7 +666,7 @@ class LocalInstrumentProvider(InstrumentProvider, ProviderBackendMixin):
|
||||
# filter
|
||||
filter_pipe = instruments["filter_pipe"]
|
||||
for filter_config in filter_pipe:
|
||||
from . import filter as F
|
||||
from . import filter as F # pylint: disable=C0415
|
||||
|
||||
filter_t = getattr(F, filter_config["filter_type"]).from_config(filter_config)
|
||||
_instruments_filtered = filter_t(_instruments_filtered, start_time, end_time, freq)
|
||||
@@ -1003,8 +997,8 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
if return_uri:
|
||||
return df, feature_uri
|
||||
return df
|
||||
except AttributeError:
|
||||
raise IOError("Unable to fetch instruments from remote server!")
|
||||
except AttributeError as attribute_e:
|
||||
raise IOError("Unable to fetch instruments from remote server!") from attribute_e
|
||||
|
||||
|
||||
class BaseProvider:
|
||||
@@ -1110,7 +1104,7 @@ class ClientProvider(BaseProvider):
|
||||
|
||||
return isinstance(instance, cls)
|
||||
|
||||
from .client import Client
|
||||
from .client import Client # pylint: disable=C0415
|
||||
|
||||
self.client = Client(C.flask_server, C.flask_port)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Callable, Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
from copy import copy, deepcopy
|
||||
from inspect import getfullargspec
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
@@ -52,7 +52,6 @@ class Dataset(Serializable):
|
||||
|
||||
- User prepare data for model based on previous status.
|
||||
"""
|
||||
pass
|
||||
|
||||
def prepare(self, **kwargs) -> object:
|
||||
"""
|
||||
@@ -68,7 +67,6 @@ class Dataset(Serializable):
|
||||
object:
|
||||
return the object
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DatasetH(Dataset):
|
||||
@@ -83,7 +81,9 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], **kwargs):
|
||||
def __init__(
|
||||
self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple], fetch_kwargs: Dict = {}, **kwargs
|
||||
):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -114,7 +114,7 @@ class DatasetH(Dataset):
|
||||
"""
|
||||
self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
self.fetch_kwargs = {}
|
||||
self.fetch_kwargs = copy(fetch_kwargs)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def config(self, handler_kwargs: dict = None, **kwargs):
|
||||
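The new `fetch_kwargs: Dict = {}` parameter uses a mutable default, which Python creates once and shares across calls; storing `copy(fetch_kwargs)` instead of the argument itself keeps instances from writing into that shared dict or into the caller's dict. A small sketch of the effect, using a hypothetical `Holder` class rather than `DatasetH`:

```python
from copy import copy


class Holder:
    def __init__(self, fetch_kwargs: dict = {}):  # shared default, as in the signature above
        # Copy so later mutation never touches the default or the caller's dict.
        self.fetch_kwargs = copy(fetch_kwargs)


a = Holder()
a.fetch_kwargs["col_set"] = "feature"
b = Holder()

caller_kwargs = {"squeeze": True}
c = Holder(caller_kwargs)
c.fetch_kwargs["squeeze"] = False
```

Note that `copy` is shallow: nested containers inside the dict would still be shared.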
@@ -164,13 +164,13 @@ class DatasetH(Dataset):
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
def _prepare_seg(self, slc, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
Given a query, retrieve the corresponding data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
slc : slice
|
||||
slc : please refer to the docs of `prepare`
|
||||
"""
|
||||
if hasattr(self, "fetch_kwargs"):
|
||||
return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs)
|
||||
@@ -179,7 +179,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice, pd.Index],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -218,22 +218,27 @@ class DatasetH(Dataset):
|
||||
NotImplementedError:
|
||||
"""
|
||||
logger = get_module_logger("DatasetH")
|
||||
fetch_kwargs = {"col_set": col_set}
|
||||
fetch_kwargs.update(kwargs)
|
||||
seg_kwargs = {"col_set": col_set}
|
||||
seg_kwargs.update(kwargs)
|
||||
if "data_key" in getfullargspec(self.handler.fetch).args:
|
||||
fetch_kwargs["data_key"] = data_key
|
||||
seg_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
# Handle all kinds of segments format
|
||||
if isinstance(segments, (list, tuple)):
|
||||
return [self._prepare_seg(slice(*self.segments[seg]), **fetch_kwargs) for seg in segments]
|
||||
elif isinstance(segments, str):
|
||||
return self._prepare_seg(slice(*self.segments[segments]), **fetch_kwargs)
|
||||
elif isinstance(segments, slice):
|
||||
return self._prepare_seg(segments, **fetch_kwargs)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
# Conflicts may happen here
# - The fetched data and the segment key may both be strings
# To resolve the conflict
# - The segment name takes priority
|
||||
|
||||
# 1) Use it as segment name first
|
||||
if isinstance(segments, str) and segments in self.segments:
|
||||
return self._prepare_seg(self.segments[segments], **seg_kwargs)
|
||||
|
||||
if isinstance(segments, (list, tuple)) and all(seg in self.segments for seg in segments):
|
||||
return [self._prepare_seg(self.segments[seg], **seg_kwargs) for seg in segments]
|
||||
|
||||
# 2) Otherwise, pass it directly to prepare a single segment
|
||||
return self._prepare_seg(segments, **seg_kwargs)
|
||||
|
||||
# helper functions
|
||||
@staticmethod
|
||||
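The rewritten `prepare` resolves its argument in two steps: a string (or a list or tuple made up only of strings) that names a configured segment is looked up first, and anything else is handed to the handler unchanged as a selector. The sketch below isolates that dispatch with the standard library. The function `resolve` and the sample segment dates are illustrative assumptions; the real method goes on to call `_prepare_seg` with the result.

```python
def resolve(segments_def, query):
    """Return what would be handed to the handler for `query`."""
    # 1) A string that names a configured segment wins.
    if isinstance(query, str) and query in segments_def:
        return segments_def[query]
    # 1b) A list or tuple made up only of segment names resolves each name.
    if isinstance(query, (list, tuple)) and all(q in segments_def for q in query):
        return [segments_def[q] for q in query]
    # 2) Anything else is passed through untouched as a selector.
    return query


segments_def = {"train": ("2008-01-01", "2014-12-31"), "test": ("2017-01-01", "2020-08-01")}
by_name = resolve(segments_def, "train")
by_names = resolve(segments_def, ["train", "test"])
passthrough = resolve(segments_def, ("2019-01-01", "2019-06-30"))
```

The priority rule matters because a date string and a segment name are both plain strings; a date that happened to equal a segment name would be treated as the name.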
@@ -341,7 +346,7 @@ class TSDataSampler:
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
self.data_index = self.data_index[np.where(self.flt_data)[0]]  # `flt_data is True` is always False for an array; use the boolean mask directly
|
||||
self.idx_map = self.idx_map2arr(self.idx_map)
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
@@ -582,8 +587,11 @@ class TSDatasetH(DatasetH):
|
||||
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
|
||||
"""
|
||||
split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data
|
||||
NOTE: TSDatasetH only supports slc segments on datetime !!!
|
||||
"""
|
||||
dtype = kwargs.pop("dtype", None)
|
||||
if not isinstance(slc, slice):
|
||||
slc = slice(*slc)
|
||||
start, end = slc.start, slc.stop
|
||||
flt_col = kwargs.pop("flt_col", None)
|
||||
# TSDatasetH will retrieve more data for complete time-series
|
||||
|
||||
@@ -2,24 +2,16 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# coding=utf-8
|
||||
import abc
|
||||
import bisect
|
||||
import logging
|
||||
import warnings
|
||||
from inspect import getfullargspec
|
||||
from typing import Callable, Union, Tuple, List, Iterator, Optional
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...data import D
|
||||
from ...config import C
|
||||
from ...utils import parse_config, transform_end_date, init_instance_by_config
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.serial import Serializable
|
||||
from .utils import fetch_df_by_index, fetch_df_by_col
|
||||
from ...utils import lazy_sort_index
|
||||
from pathlib import Path
|
||||
from .loader import DataLoader
|
||||
|
||||
from . import processor as processor_module
|
||||
@@ -154,7 +146,7 @@ class DataHandler(Serializable):
|
||||
|
||||
def fetch(
|
||||
self,
|
||||
selector: Union[pd.Timestamp, slice, str] = slice(None, None),
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
@@ -167,13 +159,24 @@ class DataHandler(Serializable):
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str]
|
||||
describe how to select data by index
|
||||
It can be categorized as follows
|
||||
- fetch single index
|
||||
- fetch a range of index
|
||||
- a slice range
|
||||
- pd.Index for specific indexes
|
||||
|
||||
The following conflicts may occur
|
||||
- Does ["20200101", "20210101"] mean selecting this slice or these two days?
|
||||
- slices have higher priority
|
||||
|
||||
level : Union[str, int]
|
||||
which index level to select the data
|
||||
|
||||
col_set : Union[str, List[str]]
|
||||
|
||||
- if isinstance(col_set, str):
|
||||
|
||||
select a set of meaningful columns.(e.g. features, columns)
|
||||
select a set of meaningful, pd.Index columns.(e.g. features, columns)
|
||||
|
||||
if col_set == CS_RAW:
|
||||
the raw dataset will be returned.
|
||||
@@ -181,6 +184,7 @@ class DataHandler(Serializable):
|
||||
- if isinstance(col_set, List[str]):
|
||||
|
||||
select several sets of meaningful columns, the returned data has multiple levels
|
||||
|
||||
proc_func: Callable
|
||||
- Give a hook for processing data before fetching
|
||||
- An example to explain the necessity of the hook:
|
||||
@@ -197,9 +201,39 @@ class DataHandler(Serializable):
|
||||
-------
|
||||
pd.DataFrame.
|
||||
"""
|
||||
from .storage import BaseHandlerStorage
|
||||
return self._fetch_data(
|
||||
data_storage=self._data,
|
||||
selector=selector,
|
||||
level=level,
|
||||
col_set=col_set,
|
||||
squeeze=squeeze,
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def _fetch_data(
|
||||
self,
|
||||
data_storage,
|
||||
selector: Union[pd.Timestamp, slice, str, pd.Index] = slice(None, None),
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set: Union[str, List[str]] = CS_ALL,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
):
|
||||
# This method is extracted for sharing in subclasses
|
||||
from .storage import BaseHandlerStorage # pylint: disable=C0415
|
||||
|
||||
# The following conflicts may occur
|
||||
# - Does ["20200101", "20210101"] mean selecting this slice or these two days?
|
||||
# To solve this issue
|
||||
# - slices have higher priority (except when level is None)
|
||||
if isinstance(selector, (tuple, list)) and level is not None:
|
||||
# when level is None, the argument will be passed in directly
|
||||
# we don't have to convert it into slice
|
||||
try:
|
||||
selector = slice(*selector)
|
||||
except ValueError:
|
||||
get_module_logger("DataHandlerLP").info("Failed to convert the query to a slice. It will be used directly.")
|
||||
|
||||
data_storage = self._data
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
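Reviewer note: the selector handling in `_fetch_data` above can be illustrated with a small standalone sketch. It is not qlib code. `normalize_selector` is a hypothetical helper name, and the toy frame stands in for a handler's data. The sketch catches `TypeError`, which is what the built-in `slice` raises when it receives more than three arguments.

```python
import pandas as pd


def normalize_selector(selector, level="datetime"):
    """Hypothetical helper: treat a 2- or 3-element list/tuple as a slice.

    When `level` is None the selector is passed through untouched,
    mirroring the "slice has higher priority" rule in the hunk above.
    """
    if isinstance(selector, (tuple, list)) and level is not None:
        try:
            return slice(*selector)
        except TypeError:
            # e.g. more than three elements: fall back to the raw selector
            return selector
    return selector


# Toy (datetime, instrument) frame, sorted so that label slicing works.
idx = pd.MultiIndex.from_product(
    [pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03"]), ["A", "B"]],
    names=["datetime", "instrument"],
)
df = pd.DataFrame({"x": range(6)}, index=idx)

sel = normalize_selector([pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-02")])
sub = df.loc[pd.IndexSlice[sel, :], :]  # rows from the first two days
```

With this rule, a two-element list is always read as a start/stop range. A caller who wants exactly those two days would pass a `pd.Index` instead, or set `level=None`.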
@@ -223,7 +257,7 @@ class DataHandler(Serializable):
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HashingStockStorage, not {type(data_storage)}")
|
||||
|
||||
if squeeze:
|
||||
# squeeze columns
|
||||
@@ -291,7 +325,18 @@ class DataHandlerLP(DataHandler):
|
||||
"""
|
||||
DataHandler with **(L)earnable (P)rocessor**
|
||||
|
||||
Tips to improving the performance of data handler
|
||||
This handler will produce three pieces of data in pd.DataFrame format.
|
||||
- DK_R / self._data: the raw data loaded from the loader
|
||||
- DK_I / self._infer: the data processed for inference
|
||||
- DK_L / self._learn: the data processed for learning model.
|
||||
|
||||
The motivation of using different processor workflows for learning and inference
|
||||
Here are some examples.
|
||||
- The instrument universe for learning and inference may be different.
|
||||
- The processing of some samples may rely on label (for example, some samples hit the limit may need extra processing or be dropped).
|
||||
These processors only apply to the learning phase.
|
||||
|
||||
Tips to improve the performance of data handler
|
||||
- To reduce the memory cost
|
||||
- `drop_raw=True`: this will modify the data inplace on raw data;
|
||||
"""
|
||||
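Reviewer note: the docstring above motivates separate processor workflows for learning and inference. A minimal sketch of that idea in plain pandas follows. The function names (`fill_feature`, `drop_missing_label`, `run`) are hypothetical and are not part of qlib, and chaining the learn processors after the infer processors is only one possible arrangement, assumed here for illustration.

```python
import pandas as pd

# Raw data, playing the role of DK_R.
raw = pd.DataFrame({"feature": [1.0, None, 3.0], "label": [0.1, 0.2, None]})


def fill_feature(df):
    """Processor that is safe for inference: it never looks at the label."""
    out = df.copy()
    out["feature"] = out["feature"].fillna(0.0)
    return out


def drop_missing_label(df):
    """Processor that relies on the label, so it only applies to learning."""
    return df.dropna(subset=["label"])


def run(df, processors):
    """Apply processors in order, returning the final frame."""
    for proc in processors:
        df = proc(df)
    return df


infer = run(raw, [fill_feature])          # plays the role of DK_I
learn = run(infer, [drop_missing_label])  # plays the role of DK_L
```

Because `fill_feature` copies its input, the raw frame is left unchanged. This is the memory cost that the `drop_raw=True` tip in the docstring trades away.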
@@ -551,6 +596,7 @@ class DataHandlerLP(DataHandler):
|
||||
level: Union[str, int] = "datetime",
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key: str = DK_I,
|
||||
squeeze: bool = False,
|
||||
proc_func: Callable = None,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
@@ -573,36 +619,15 @@ class DataHandlerLP(DataHandler):
|
||||
-------
|
||||
pd.DataFrame:
|
||||
"""
|
||||
from .storage import BaseHandlerStorage
|
||||
|
||||
data_storage = self._get_df_by_key(data_key)
|
||||
if isinstance(data_storage, pd.DataFrame):
|
||||
data_df = data_storage
|
||||
if proc_func is not None:
|
||||
# FIXME: fetch by time first will be more friendly to proc_func
|
||||
# Copy incase of `proc_func` changing the data inplace....
|
||||
data_df = proc_func(fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig).copy())
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
else:
|
||||
# Fetch column first will be more friendly to SepDataFrame
|
||||
data_df = fetch_df_by_col(data_df, col_set)
|
||||
data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig)
|
||||
|
||||
elif isinstance(data_storage, BaseHandlerStorage):
|
||||
if not data_storage.is_proc_func_supported():
|
||||
if proc_func is not None:
|
||||
raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}")
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig
|
||||
)
|
||||
else:
|
||||
data_df = data_storage.fetch(
|
||||
selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}")
|
||||
|
||||
return data_df
|
||||
return self._fetch_data(
|
||||
data_storage=self._get_df_by_key(data_key),
|
||||
selector=selector,
|
||||
level=level,
|
||||
col_set=col_set,
|
||||
squeeze=squeeze,
|
||||
proc_func=proc_func,
|
||||
)
|
||||
|
||||
def get_cols(self, col_set=DataHandler.CS_ALL, data_key: str = DK_I) -> list:
|
||||
"""
|
||||
|
||||
@@ -51,7 +51,6 @@ class DataLoader(abc.ABC):
|
||||
pd.DataFrame:
|
||||
data load from the under layer source
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class DLWParser(DataLoader):
|
||||
@@ -129,7 +128,6 @@ class DLWParser(DataLoader):
|
||||
pd.DataFrame:
|
||||
the queried dataframe.
|
||||
"""
|
||||
pass
|
||||
|
||||
def load(self, instruments=None, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if self.is_group:
|
||||
@@ -308,7 +306,7 @@ class DataLoaderDH(DataLoader):
|
||||
is_group will be used to describe whether the key of handler_config is group
|
||||
|
||||
"""
|
||||
from qlib.data.dataset.handler import DataHandler
|
||||
from qlib.data.dataset.handler import DataHandler # pylint: disable=C0415
|
||||
|
||||
if is_group:
|
||||
self.handlers = {
|
||||
|
||||
@@ -42,7 +42,6 @@ class Processor(Serializable):
|
||||
processor, i.e. `df`.
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, df: pd.DataFrame):
|
||||
@@ -57,7 +56,6 @@ class Processor(Serializable):
|
||||
df : pd.DataFrame
|
||||
The raw_df of handler or result from previous processor.
|
||||
"""
|
||||
pass
|
||||
|
||||
def is_for_infer(self) -> bool:
|
||||
"""
|
||||
@@ -201,7 +199,7 @@ class MinMaxNorm(Processor):
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
|
||||
def fit(self, df):
|
||||
def fit(self, df: pd.DataFrame = None):
|
||||
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
self.min_val = np.nanmin(df[cols].values, axis=0)
|
||||
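Reviewer note: the `fit` method above restricts the frame to the fit window and then takes a column-wise `np.nanmin`. A simplified, standalone sketch of that pattern is below. The scaling step uses the standard min-max formula, which is an assumption here since the hunk does not show how the fitted values are applied, and the toy data is made up.

```python
import numpy as np
import pandas as pd

dates = pd.to_datetime(["2020-01-01", "2020-01-02", "2020-01-03", "2020-01-04"])
idx = pd.MultiIndex.from_product([dates, ["A"]], names=["datetime", "instrument"])
df = pd.DataFrame(
    {"f1": [1.0, np.nan, 3.0, 5.0], "f2": [2.0, 4.0, 6.0, 8.0]}, index=idx
)

# Fit only on the first three days so later data cannot leak into the statistics.
fit_window = slice(pd.Timestamp("2020-01-01"), pd.Timestamp("2020-01-03"))
fit_df = df.loc[pd.IndexSlice[fit_window, :], :]

# NaN-aware column-wise statistics, as in the hunk above.
min_val = np.nanmin(fit_df.values, axis=0)
max_val = np.nanmax(fit_df.values, axis=0)

# Standard min-max scaling (assumed); values outside the fit window may exceed 1.
scaled = (df - min_val) / (max_val - min_val)
```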
@@ -232,7 +230,7 @@ class ZScoreNorm(Processor):
|
||||
self.fit_end_time = fit_end_time
|
||||
self.fields_group = fields_group
|
||||
|
||||
def fit(self, df):
|
||||
def fit(self, df: pd.DataFrame = None):
|
||||
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
self.mean_train = np.nanmean(df[cols].values, axis=0)
|
||||
@@ -272,7 +270,7 @@ class RobustZScoreNorm(Processor):
|
||||
self.fields_group = fields_group
|
||||
self.clip_outlier = clip_outlier
|
||||
|
||||
def fit(self, df):
|
||||
def fit(self, df: pd.DataFrame = None):
|
||||
df = fetch_df_by_index(df, slice(self.fit_start_time, self.fit_end_time), level="datetime")
|
||||
self.cols = get_group_columns(df, self.fields_group)
|
||||
X = df[self.cols].values
|
||||
@@ -351,6 +349,6 @@ class HashStockFormat(Processor):
|
||||
"""Process the storage of from df into hasing stock format"""
|
||||
|
||||
def __call__(self, df: pd.DataFrame):
|
||||
from .storage import HasingStockStorage
|
||||
from .storage import HashingStockStorage # pylint: disable=C0415
|
||||
|
||||
return HasingStockStorage.from_df(df)
|
||||
return HashingStockStorage.from_df(df)
|
||||
|
||||
@@ -2,7 +2,7 @@ import pandas as pd
|
||||
import numpy as np
|
||||
|
||||
from .handler import DataHandler
|
||||
from typing import Tuple, Union, List, Callable
|
||||
from typing import Union, List, Callable
|
||||
|
||||
from .utils import get_level_index, fetch_df_by_index, fetch_df_by_col
|
||||
|
||||
@@ -59,11 +59,11 @@ class BaseHandlerStorage:
|
||||
raise NotImplementedError("is_proc_func_supported method is not implemented!")
|
||||
|
||||
|
||||
class HasingStockStorage(BaseHandlerStorage):
|
||||
"""Hasing data storage for datahanlder
|
||||
class HashingStockStorage(BaseHandlerStorage):
|
||||
"""Hashing data storage for datahanlder
|
||||
- The default data storage pandas.DataFrame is too slow when randomly accessing one stock's data
|
||||
- HasingStockStorage hashes the multiple stocks' data(pandas.DataFrame) by the key `stock_id`.
|
||||
- HasingStockStorage hashes the pandas.DataFrame into a dict, whose key is the stock_id(str) and value this stock data(panda.DataFrame), it has the following format:
|
||||
- HashingStockStorage hashes the multiple stocks' data(pandas.DataFrame) by the key `stock_id`.
|
||||
- HashingStockStorage hashes the pandas.DataFrame into a dict, whose key is the stock_id(str) and value this stock data(panda.DataFrame), it has the following format:
|
||||
{
|
||||
stock1_id: stock1_data,
|
||||
stock2_id: stock2_data,
|
||||
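Reviewer note: the docstring above describes hashing one DataFrame into a dict keyed by stock id. A minimal sketch of that layout in plain pandas follows. It does not reproduce the actual `HashingStockStorage` implementation; the instrument codes and the `close` column are made-up sample data.

```python
import pandas as pd

idx = pd.MultiIndex.from_product(
    [pd.to_datetime(["2020-01-01", "2020-01-02"]), ["SH600000", "SH600004"]],
    names=["datetime", "instrument"],
)
df = pd.DataFrame({"close": [10.0, 20.0, 11.0, 21.0]}, index=idx)

# One sub-frame per stock: looking up a single stock becomes a dict access
# instead of a scan or index lookup over the whole frame.
hash_df = {stock: stock_df for stock, stock_df in df.groupby(level="instrument")}

one_stock = hash_df["SH600000"]
```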
@@ -82,7 +82,7 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
|
||||
@staticmethod
|
||||
def from_df(df):
|
||||
return HasingStockStorage(df)
|
||||
return HashingStockStorage(df)
|
||||
|
||||
def _fetch_hash_df_by_stock(self, selector, level):
|
||||
"""fetch the data with stock selector
|
||||
@@ -109,7 +109,7 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
stock_selector = selector[self.stock_level]
|
||||
elif isinstance(selector, (list, str)) and self.stock_level == 0:
|
||||
stock_selector = selector
|
||||
elif level == "instrument" or level == self.stock_level:
|
||||
elif level in ("instrument", self.stock_level):
|
||||
if isinstance(selector, tuple):
|
||||
stock_selector = selector[0]
|
||||
elif isinstance(selector, (list, str)):
|
||||
@@ -153,5 +153,5 @@ class HasingStockStorage(BaseHandlerStorage):
|
||||
return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig)
|
||||
|
||||
def is_proc_func_supported(self):
|
||||
"""the arg `proc_func` in `fetch` method is not supported in HasingStockStorage"""
|
||||
"""the arg `proc_func` in `fetch` method is not supported in HashingStockStorage"""
|
||||
return False
|
||||
|
||||
@@ -41,13 +41,16 @@ def get_level_index(df: pd.DataFrame, level=Union[str, int]) -> int:
|
||||
|
||||
def fetch_df_by_index(
|
||||
df: pd.DataFrame,
|
||||
selector: Union[pd.Timestamp, slice, str, list],
|
||||
selector: Union[pd.Timestamp, slice, str, list, pd.Index],
|
||||
level: Union[str, int],
|
||||
fetch_orig=True,
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
fetch data from `data` with `selector` and `level`
|
||||
|
||||
selector are assumed to be well processed.
|
||||
`fetch_df_by_index` is only responsible for get the right level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
selector : Union[pd.Timestamp, slice, str, list]
|
||||
@@ -60,7 +63,7 @@ def fetch_df_by_index(
|
||||
Data of the given index.
|
||||
"""
|
||||
# level = None -> use selector directly
|
||||
if level == None:
|
||||
if level is None or isinstance(selector, pd.MultiIndex):
|
||||
return df.loc(axis=0)[selector]
|
||||
# Try to get the right index
|
||||
idx_slc = (selector, slice(None, None))
|
||||
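Reviewer note: the new branch above uses the selector directly when it is a `pd.MultiIndex`. A small standalone sketch of what `df.loc(axis=0)[selector]` returns in that case is below. The toy frame and the `wanted` index are illustrative only.

```python
import pandas as pd

idx = pd.MultiIndex.from_product(
    [pd.to_datetime(["2020-01-01", "2020-01-02"]), ["A", "B"]],
    names=["datetime", "instrument"],
)
df = pd.DataFrame({"x": [10, 11, 12, 13]}, index=idx)

# Specific (datetime, instrument) pairs rather than a range on one level.
wanted = pd.MultiIndex.from_tuples(
    [(pd.Timestamp("2020-01-01"), "B"), (pd.Timestamp("2020-01-02"), "A")],
    names=["datetime", "instrument"],
)

# Same call shape as in the hunk above: the selector addresses full index rows.
picked = df.loc(axis=0)[wanted]
```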
@@ -72,7 +75,7 @@ def fetch_df_by_index(
|
||||
return df.loc[
|
||||
pd.IndexSlice[idx_slc],
|
||||
]
|
||||
else:
|
||||
else: # pylint: disable=W0120
|
||||
return df
|
||||
else:
|
||||
return df.loc[
|
||||
@@ -81,7 +84,7 @@ def fetch_df_by_index(
|
||||
|
||||
|
||||
def fetch_df_by_col(df: pd.DataFrame, col_set: Union[str, List[str]]) -> pd.DataFrame:
|
||||
from .handler import DataHandler
|
||||
from .handler import DataHandler # pylint: disable=C0415
|
||||
|
||||
if not isinstance(df.columns, pd.MultiIndex) or col_set == DataHandler.CS_RAW:
|
||||
return df
|
||||
@@ -133,7 +136,7 @@ def init_task_handler(task: dict) -> Union[DataHandler, None]:
|
||||
returns
|
||||
"""
|
||||
# avoid recursive import
|
||||
from .handler import DataHandler
|
||||
from .handler import DataHandler # pylint: disable=C0415
|
||||
|
||||
h_conf = task["dataset"]["kwargs"].get("handler")
|
||||
if h_conf is not None:
|
||||
|
||||
@@ -1,13 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from typing import Union, List, Tuple
|
||||
from ...data.dataset import TSDataSampler
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...utils import lazy_sort_index
|
||||
|
||||
|
||||
class Reweighter:
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
||||
@@ -62,7 +62,7 @@ class SeriesDFilter(BaseDFilter):
|
||||
Override _getFilterSeries to use the rule to filter the series and get a dict of {inst => series}, or override filter_main for more advanced series filter rule
|
||||
"""
|
||||
|
||||
def __init__(self, fstart_time=None, fend_time=None):
|
||||
def __init__(self, fstart_time=None, fend_time=None, keep=False):
|
||||
"""Init function for filter base class.
|
||||
Filter a set of instruments based on a certain rule within a certain period assigned by fstart_time and fend_time.
|
||||
|
||||
@@ -72,10 +72,13 @@ class SeriesDFilter(BaseDFilter):
|
||||
the time for the filter rule to start filter the instruments.
|
||||
fend_time: str
|
||||
the time for the filter rule to stop filter the instruments.
|
||||
keep: bool
|
||||
whether to keep the instruments of which features don't exist in the filter time span.
|
||||
"""
|
||||
super(SeriesDFilter, self).__init__()
|
||||
self.filter_start_time = pd.Timestamp(fstart_time) if fstart_time else None
|
||||
self.filter_end_time = pd.Timestamp(fend_time) if fend_time else None
|
||||
self.keep = keep
|
||||
|
||||
def _getTimeBound(self, instruments):
|
||||
"""Get time bound for all instruments.
|
||||
@@ -330,12 +333,9 @@ class ExpressionDFilter(SeriesDFilter):
|
||||
filter the feature ending by this time.
|
||||
rule_expression: str
|
||||
an input expression for the rule.
|
||||
keep: bool
|
||||
whether to keep the instruments of which features don't exist in the filter time span.
|
||||
"""
|
||||
super(ExpressionDFilter, self).__init__(fstart_time, fend_time)
|
||||
super(ExpressionDFilter, self).__init__(fstart_time, fend_time, keep=keep)
|
||||
self.rule_expression = rule_expression
|
||||
self.keep = keep
|
||||
|
||||
def _getFilterSeries(self, instruments, fstart, fend):
|
||||
# do not use dataset cache
|
||||
|
||||
Some files were not shown because too many files have changed in this diff