mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge branch 'main' into online_srv
This commit is contained in:
12
.deepsource.toml
Normal file
12
.deepsource.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["tests/test_*.py"]
|
||||
|
||||
exclude_patterns = ["examples/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
enabled = true
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,3 +34,7 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
56
README.md
56
README.md
@@ -17,10 +17,11 @@ Qlib is an AI-oriented quantitative investment platform, which aims to realize t
|
||||
|
||||
It contains the full ML pipeline of data processing, model training, back-testing; and covers the entire chain of quantitative investment: alpha seeking, risk modeling, portfolio optimization, and order execution.
|
||||
|
||||
With Qlib, user can easily try ideas to create better Quant investment strategies.
|
||||
With Qlib, users can easily try ideas to create better Quant investment strategies.
|
||||
|
||||
For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative Investment Platform"](https://arxiv.org/abs/2009.11189).
|
||||
|
||||
- [**News and Plans**](#news-and-plans)
|
||||
- [Framework of Qlib](#framework-of-qlib)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Installation](#installation)
|
||||
@@ -35,9 +36,32 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
- [Offline Mode and Online Mode](#offline-mode-and-online-mode)
|
||||
- [Performance of Qlib Data Server](#performance-of-qlib-data-server)
|
||||
- [Related Reports](#related-reports)
|
||||
- [Contact Us](#contact-us)
|
||||
- [Contributing](#contributing)
|
||||
|
||||
|
||||
# News And Plans
|
||||
New features under development(order by estimated release time).
|
||||
Your feedbacks about the features are very important.
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| Online serving and automatic model rolling | Under review: https://github.com/microsoft/qlib/pull/290 |
|
||||
| Planning-based portfolio optimization | Under review: https://github.com/microsoft/qlib/pull/280 |
|
||||
| Fund data supporting and analysis | Under review: https://github.com/microsoft/qlib/pull/292 |
|
||||
| Point-in-Time database | Under review: https://github.com/microsoft/qlib/pull/343 |
|
||||
| High-frequency trading | Initial opensource version under development |
|
||||
| Meta-Learning-based data selection | Initial opensource version under development |
|
||||
|
||||
Recent released features
|
||||
| Feature | Status |
|
||||
| -- | ------ |
|
||||
| DoubleEnsemble Model | Released https://github.com/microsoft/qlib/pull/286 |
|
||||
| High-frequency data processing example | Released https://github.com/microsoft/qlib/pull/257 |
|
||||
| High-frequency trading example | Part of code released https://github.com/microsoft/qlib/pull/227 |
|
||||
| High-frequency data(1min) | Released https://github.com/microsoft/qlib/pull/221 |
|
||||
| Tabnet Model | Released https://github.com/microsoft/qlib/pull/205 |
|
||||
|
||||
Features released before 2021 are not listed here.
|
||||
|
||||
# Framework of Qlib
|
||||
|
||||
@@ -46,11 +70,11 @@ For more details, please refer to our paper ["Qlib: An AI-oriented Quantitative
|
||||
</div>
|
||||
|
||||
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules and each component could be used stand-alone.
|
||||
At the module level, Qlib is a platform that consists of the above components. The components are designed as loose-coupled modules, and each component could be used stand-alone.
|
||||
|
||||
| Name | Description |
|
||||
| ------ | ----- |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides flexible interface to control the training process of models which enable algorithms controlling the training process. |
|
||||
| `Infrastructure` layer | `Infrastructure` layer provides underlying support for Quant research. `DataServer` provides a high-performance infrastructure for users to manage and retrieve raw data. `Trainer` provides a flexible interface to control the training process of models, which enable algorithms to control the training process. |
|
||||
| `Workflow` layer | `Workflow` layer covers the whole workflow of quantitative investment. `Information Extractor` extracts data for models. `Forecast Model` focuses on producing all kinds of forecast signals (e.g. _alpha_, risk) for other modules. With these signals `Portfolio Generator` will generate the target portfolio and produce orders to be executed by `Order Executor`. |
|
||||
| `Interface` layer | `Interface` layer tries to present a user-friendly interface for the underlying system. `Analyser` module will provide users detailed analysis reports of forecasting signals, portfolios and execution results |
|
||||
|
||||
@@ -130,7 +154,8 @@ This dataset is created by public data collected by [crawler scripts](scripts/da
|
||||
the same repository.
|
||||
Users could create the same dataset with it.
|
||||
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
*Please pay **ATTENTION** that the data is collected from [Yahoo Finance](https://finance.yahoo.com/lookup), and the data might not be perfect.
|
||||
We recommend users to prepare their own data if they have a high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*.
|
||||
|
||||
<!--
|
||||
- Run the initialization code and get stock data:
|
||||
@@ -220,7 +245,7 @@ Qlib provides a tool named `qrun` to run the whole workflow automatically (inclu
|
||||
-->
|
||||
|
||||
## Building Customized Quant Research Workflow by Code
|
||||
The automatic workflow may not suite the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
The automatic workflow may not suit the research workflow of all Quant researchers. To support a flexible Quant research workflow, Qlib also provides a modularized interface to allow researchers to build their own workflow by code. [Here](examples/workflow_by_code.ipynb) is a demo for customized Quant research workflow by code.
|
||||
|
||||
|
||||
# [Quant Model Zoo](examples/benchmarks)
|
||||
@@ -237,6 +262,7 @@ Here is a list of models built on `Qlib`.
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -246,10 +272,10 @@ The performance of each model on the `Alpha158` and `Alpha360` dataset can be fo
|
||||
All the models listed above are runnable with ``Qlib``. Users can find the config files we provide and some details about the model through the [benchmarks](examples/benchmarks) folder. More information can be retrieved at the model files listed above.
|
||||
|
||||
`Qlib` provides three different ways to run a single model, users can pick the one that fits their cases best:
|
||||
- User can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- User can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
- Users can use the tool `qrun` mentioned above to run a model's workflow based from a config file.
|
||||
- Users can create a `workflow_by_code` python script based on the [one](examples/workflow_by_code.py) listed in the `examples` folder.
|
||||
|
||||
- User can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
- Users can use the script [`run_all_model.py`](examples/run_all_model.py) listed in the `examples` folder to run a model. Here is an example of the specific shell command to be used: `python run_all_model.py --models=lightgbm`, where the `--models` arguments can take any number of models listed above(the available models can be found in [benchmarks](examples/benchmarks/)). For more use cases, please refer to the file's [docstrings](examples/run_all_model.py).
|
||||
|
||||
## Run multiple models
|
||||
`Qlib` also provides a script [`run_all_model.py`](examples/run_all_model.py) which can run multiple models for several iterations. (**Note**: the script only support *Linux* for now. Other OS will be supported in the future. Besides, it doesn't support parrallel running the same model for multiple times as well, and this will be fixed in the future development too.)
|
||||
@@ -312,17 +338,27 @@ which creates a dataset (14 features/factors) from the basic OHLCV daily data of
|
||||
* `+(-)E` indicates with (out) `ExpressionCache`
|
||||
* `+(-)D` indicates with (out) `DatasetCache`
|
||||
|
||||
Most general-purpose databases take too much time on loading data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Most general-purpose databases take too much time to load data. After looking into the underlying implementation, we find that data go through too many layers of interfaces and unnecessary format transformations in general-purpose database solutions.
|
||||
Such overheads greatly slow down the data loading process.
|
||||
Qlib data are stored in a compact format, which is efficient to be combined into arrays for scientific computation.
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
- [微矿Qlib:业内首个AI量化投资开源平台](https://mp.weixin.qq.com/s/vsJv7lsgjEi-ALYUz4CvtQ)
|
||||
|
||||
# Contact Us
|
||||
- If you have any issues, please create issue [here](https://github.com/microsoft/qlib/issues/new/choose) or send messages in [gitter](https://gitter.im/Microsoft/qlib).
|
||||
- If you want to make contributions to `Qlib`, please [create pull requests](https://github.com/microsoft/qlib/compare).
|
||||
- For other reasons, you are welcome to contact us by email([qlib@microsoft.com](mailto:qlib@microsoft.com)).
|
||||
- We are recruiting new members(both FTEs and interns), your resumes are welcome!
|
||||
|
||||
Join IM discussion groups:
|
||||
|[Gitter](https://gitter.im/Microsoft/qlib)|
|
||||
|----|
|
||||
||
|
||||
|
||||
# Contributing
|
||||
|
||||
|
||||
@@ -70,3 +70,31 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
|
||||
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "qlib/qlib/__init__.py", line 19, in init
|
||||
from .data.cache import H
|
||||
File "qlib/qlib/data/__init__.py", line 8, in <module>
|
||||
from .data import (
|
||||
File "qlib/qlib/data/data.py", line 20, in <module>
|
||||
from .cache import H
|
||||
File "qlib/qlib/data/cache.py", line 36, in <module>
|
||||
from .ops import Operators
|
||||
File "qlib/qlib/data/ops.py", line 19, in <module>
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Normal file
BIN
docs/_static/img/qrcode/gitter_qr.png
vendored
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 7.2 KiB |
@@ -61,7 +61,7 @@ In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, whic
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -163,7 +163,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -174,7 +174,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -218,6 +218,25 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
|
||||
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# DoubleEnsemble
|
||||
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
|
||||
* This code used in Qlib is implemented by ourselves.
|
||||
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).
|
||||
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
@@ -0,0 +1,90 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,97 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -29,7 +29,7 @@ data_handler_config: &data_handler_config
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
|
||||
@@ -16,6 +16,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -25,11 +26,12 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
Binary file not shown.
@@ -44,6 +44,7 @@ task:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 158
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
@@ -55,7 +56,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TabnetModel
|
||||
module_path: qlib.contrib.model.pytorch_tabnet
|
||||
kwargs:
|
||||
d_feat: 360
|
||||
pretrain: True
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -108,7 +108,7 @@ _default_config = {
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
# logging_level can control the logging level more finely
|
||||
"logging_config": {
|
||||
@@ -127,12 +127,12 @@ _default_config = {
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"level": logging.DEBUG,
|
||||
"formatter": "logger_format",
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
"exp_manager": {
|
||||
@@ -188,7 +188,7 @@ MODE_CONF = {
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
# serversS(such as PAI) [auto_mount:True]
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
|
||||
@@ -104,10 +104,9 @@ class Account:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trader.check_stock_suspended(code, today):
|
||||
continue
|
||||
else:
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
|
||||
@@ -61,7 +61,7 @@ def get_position_value(evaluate_date, position):
|
||||
# load close price for position
|
||||
# position should also consider cash
|
||||
instruments = list(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
fields = ["$close"]
|
||||
close_data_df = D.features(
|
||||
instruments,
|
||||
@@ -80,7 +80,7 @@ def get_position_list_value(positions):
|
||||
instruments = set()
|
||||
for day, position in positions.items():
|
||||
instruments.update(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
instruments.sort()
|
||||
day_list = list(positions.keys())
|
||||
day_list.sort()
|
||||
|
||||
247
qlib/contrib/model/double_ensemble.py
Normal file
247
qlib/contrib/model/double_ensemble.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
class DEnsembleModel(Model):
|
||||
"""Double Ensemble Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model="gbm",
|
||||
loss="mse",
|
||||
num_models=6,
|
||||
enable_sr=True,
|
||||
enable_fs=True,
|
||||
alpha1=1.0,
|
||||
alpha2=1.0,
|
||||
bins_sr=10,
|
||||
bins_fs=5,
|
||||
decay=None,
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
self.enable_sr = enable_sr
|
||||
self.enable_fs = enable_fs
|
||||
self.alpha1 = alpha1
|
||||
self.alpha2 = alpha2
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
if not len(sub_weights) == num_models:
|
||||
raise ValueError("The length of sub_weights should be equal to num_models.")
|
||||
self.sub_weights = sub_weights
|
||||
self.epochs = epochs
|
||||
self.logger = get_module_logger("DEnsembleModel")
|
||||
self.logger.info("Double Ensemble Model...")
|
||||
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
|
||||
self.sub_features = [] # the features for each sub model in the form of pandas.Index
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
weights = pd.Series(np.ones(N, dtype=float))
|
||||
# initialize the features
|
||||
features = x_train.columns
|
||||
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
|
||||
# train sub-models
|
||||
for k in range(self.num_models):
|
||||
self.sub_features.append(features)
|
||||
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
|
||||
model_k = self.train_submodel(df_train, df_valid, weights, features)
|
||||
self.ensemble.append(model_k)
|
||||
# no further sample re-weight and feature selection needed for the last sub-model
|
||||
if k + 1 == self.num_models:
|
||||
break
|
||||
|
||||
self.logger.info("Retrieving loss curve and loss values...")
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
self.logger.info("Sample re-weighting...")
|
||||
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
|
||||
|
||||
if self.enable_fs:
|
||||
self.logger.info("Feature selection...")
|
||||
features = self.feature_selection(df_train, loss_values)
|
||||
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
return model
|
||||
|
||||
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def sample_reweight(self, loss_curve, loss_values, k_th):
|
||||
"""
|
||||
the SR module of Double Ensemble
|
||||
:param loss_curve: the shape is NxT
|
||||
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
|
||||
after the t-th iteration in the training of the previous sub-model.
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:param k_th: the index of the current sub-model, starting from 1
|
||||
:return: weights
|
||||
the weights for all the samples.
|
||||
"""
|
||||
# normalize loss_curve and loss_values with ranking
|
||||
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
|
||||
loss_values_norm = (-loss_values).rank(pct=True)
|
||||
|
||||
# calculate l_start and l_end from loss_curve
|
||||
N, T = loss_curve.shape
|
||||
part = np.maximum(int(T * 0.1), 1)
|
||||
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
|
||||
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
|
||||
|
||||
# calculate h-value for each sample
|
||||
h1 = loss_values_norm
|
||||
h2 = (l_end / l_start).rank(pct=True)
|
||||
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
"""
|
||||
the FS module of Double Ensemble
|
||||
:param df_train: the shape is NxF
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:return: res_feat: in the form of pandas.Index
|
||||
|
||||
"""
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
features = x_train.columns
|
||||
N, F = x_train.shape
|
||||
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
|
||||
M = len(self.ensemble)
|
||||
|
||||
# shuffle specific columns and calculate g-value for each feature
|
||||
x_train_tmp = x_train.copy()
|
||||
for i_f, feat in enumerate(features):
|
||||
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
|
||||
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
|
||||
for i_s, submodel in enumerate(self.ensemble):
|
||||
pred += (
|
||||
pd.Series(
|
||||
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
|
||||
)
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
g["g_value"].replace(np.nan, 0, inplace=True)
|
||||
|
||||
# divide features into bins_fs bins
|
||||
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
|
||||
|
||||
# randomly sample features from bins to construct the new features
|
||||
res_feat = []
|
||||
sorted_bins = sorted(g["bins"].unique(), reverse=True)
|
||||
for i_b, b in enumerate(sorted_bins):
|
||||
b_feat = features[g["bins"] == b]
|
||||
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
|
||||
res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist()
|
||||
return pd.Index(res_feat)
|
||||
|
||||
def get_loss(self, label, pred):
|
||||
if self.loss == "mse":
|
||||
return (label - pred) ** 2
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
|
||||
def retrieve_loss_curve(self, model, df_train, features):
|
||||
if self.base_model == "gbm":
|
||||
num_trees = model.num_trees()
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train = np.squeeze(y_train.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
N = x_train.shape[0]
|
||||
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
|
||||
pred_tree = np.zeros(N, dtype=float)
|
||||
for i_tree in range(num_trees):
|
||||
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
|
||||
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
pred += (
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
@@ -9,20 +9,19 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -39,8 +38,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -76,8 +75,7 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -93,7 +91,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -107,7 +105,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -123,6 +121,9 @@ class ALSTM(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +134,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +206,13 @@ class ALSTM(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +220,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +232,7 @@ class ALSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +294,7 @@ class ALSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(x_batch).detach().numpy()
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,21 +9,20 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,8 +39,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,9 +77,8 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +94,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +109,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +125,10 @@ class ALSTM(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +139,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +193,13 @@ class ALSTM(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +207,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +215,14 @@ class ALSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +279,7 @@ class ALSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,19 +9,18 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -42,8 +41,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -83,7 +82,7 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
@@ -102,7 +101,7 @@ class GATs(Model):
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -118,7 +117,7 @@ class GATs(Model):
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -135,6 +134,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -145,6 +147,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -232,7 +238,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -245,8 +250,7 @@ class GATs(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
@@ -324,10 +328,7 @@ class GATs(Model):
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(x_batch).detach().numpy()
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,21 +9,20 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -62,8 +61,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -104,9 +103,8 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -157,6 +155,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -167,6 +168,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -245,7 +250,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -258,11 +262,10 @@ class GATs(Model):
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -345,10 +348,7 @@ class GATs(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(feature.float()).detach().numpy()
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,20 +9,19 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -76,8 +75,7 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -123,6 +121,9 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +134,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.gru_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +206,13 @@ class GRU(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +220,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +232,7 @@ class GRU(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +294,7 @@ class GRU(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.gru_model(x_batch).detach().numpy()
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,21 +9,20 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -78,9 +77,8 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +94,7 @@ class GRU(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +109,7 @@ class GRU(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +125,10 @@ class GRU(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +139,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.GRU_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +193,13 @@ class GRU(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +207,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +215,14 @@ class GRU(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +279,7 @@ class GRU(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GRU_model(feature.float()).detach().numpy()
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,15 +9,13 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -76,8 +74,7 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -133,6 +130,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.lstm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -214,7 +215,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +227,7 @@ class LSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +289,7 @@ class LSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.lstm_model(x_batch).detach().numpy()
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -9,15 +9,13 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -78,9 +76,8 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +93,7 @@ class LSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +108,7 @@ class LSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -138,6 +135,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.LSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,7 +202,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +210,14 @@ class LSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +274,7 @@ class LSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.LSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import logging
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
@@ -15,11 +14,12 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger
|
||||
from ...workflow import R
|
||||
|
||||
|
||||
@@ -42,8 +42,8 @@ class DNNModelPytorch(Model):
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -80,8 +80,7 @@ class DNNModelPytorch(Model):
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
@@ -99,7 +98,7 @@ class DNNModelPytorch(Model):
|
||||
"\nloss_type : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
@@ -114,8 +113,8 @@ class DNNModelPytorch(Model):
|
||||
loss,
|
||||
eval_steps,
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
)
|
||||
)
|
||||
@@ -129,6 +128,9 @@ class DNNModelPytorch(Model):
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -153,6 +155,10 @@ class DNNModelPytorch(Model):
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
@@ -172,7 +178,7 @@ class DNNModelPytorch(Model):
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
@@ -215,7 +221,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
if step and step % self.eval_steps == 0:
|
||||
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
@@ -248,9 +255,9 @@ class DNNModelPytorch(Model):
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training ??
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
if self.use_GPU:
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_loss(self, pred, w, target, loss_type):
|
||||
@@ -272,10 +279,7 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_GPU:
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
|
||||
@@ -8,21 +8,20 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -196,8 +195,8 @@ class SFM(Model):
|
||||
learning rate
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -216,7 +215,7 @@ class SFM(Model):
|
||||
eval_steps=5,
|
||||
loss="mse",
|
||||
optimizer="gd",
|
||||
GPU="0",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -239,8 +238,7 @@ class SFM(Model):
|
||||
self.eval_steps = eval_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -259,7 +257,7 @@ class SFM(Model):
|
||||
"\neval_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -276,7 +274,7 @@ class SFM(Model):
|
||||
eval_steps,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -295,6 +293,9 @@ class SFM(Model):
|
||||
dropout_U=self.dropout_U,
|
||||
device=self.device,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.sfm_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -305,6 +306,10 @@ class SFM(Model):
|
||||
self.fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
@@ -365,7 +370,6 @@ class SFM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -377,6 +381,7 @@ class SFM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -409,7 +414,10 @@ class SFM(Model):
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sfm_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
if self.device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -7,15 +7,13 @@ import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...log import get_module_logger
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -23,6 +21,7 @@ import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -49,12 +48,12 @@ class TabnetModel(Model):
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
GPU=0,
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
pretrain_file=None,
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
@@ -75,29 +74,27 @@ class TabnetModel(Model):
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.pretrain_file = get_or_create_path(pretrain_file)
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
"\ndevice : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, self.device, self.pretrain)
|
||||
)
|
||||
self.fitted = False
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tabnet_model = TabNet(
|
||||
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
|
||||
).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
@@ -113,11 +110,12 @@ class TabnetModel(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
get_or_create_path(pretrain_file)
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
@@ -159,7 +157,6 @@ class TabnetModel(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
@@ -179,10 +176,11 @@ class TabnetModel(Model):
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
@@ -201,16 +199,23 @@ class TabnetModel(Model):
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
best_param = copy.deepcopy(self.tabnet_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.tabnet_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
@@ -260,12 +265,13 @@ class TabnetModel(Model):
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -348,10 +354,11 @@ class TabnetModel(Model):
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
with torch.no_grad():
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
@@ -397,9 +404,9 @@ class FinetuneModel(nn.Module):
|
||||
|
||||
|
||||
class DecoderStep(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super().__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)
|
||||
self.fc = nn.Linear(out_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -408,13 +415,12 @@ class DecoderStep(nn.Module):
|
||||
|
||||
|
||||
class TabNet_Decoder(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):
|
||||
"""
|
||||
TabNet decoder that is used in pre-training
|
||||
"""
|
||||
self.out_dim = out_dim
|
||||
|
||||
super().__init__()
|
||||
self.out_dim = out_dim
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
|
||||
@@ -425,7 +431,7 @@ class TabNet_Decoder(nn.Module):
|
||||
self.n_steps = n_steps
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps):
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))
|
||||
|
||||
def forward(self, x):
|
||||
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
|
||||
@@ -435,9 +441,7 @@ class TabNet_Decoder(nn.Module):
|
||||
|
||||
|
||||
class TabNet(nn.Module):
|
||||
def __init__(
|
||||
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
|
||||
):
|
||||
def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):
|
||||
"""
|
||||
TabNet AKA the original encoder
|
||||
|
||||
@@ -461,10 +465,10 @@ class TabNet(nn.Module):
|
||||
else:
|
||||
self.shared = None
|
||||
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps - 1):
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
|
||||
self.fc = nn.Linear(n_d, out_dim)
|
||||
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
|
||||
self.n_d = n_d
|
||||
@@ -473,14 +477,14 @@ class TabNet(nn.Module):
|
||||
assert not torch.isnan(x).any()
|
||||
x = self.bn(x)
|
||||
x_a = self.first_step(x)[:, self.n_d :]
|
||||
sparse_loss = torch.zeros(1).to(x.device)
|
||||
sparse_loss = []
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss += l
|
||||
return self.fc(out), sparse_loss
|
||||
sparse_loss.append(l)
|
||||
return self.fc(out), sum(sparse_loss)
|
||||
|
||||
|
||||
class GBN(nn.Module):
|
||||
@@ -498,9 +502,12 @@ class GBN(nn.Module):
|
||||
self.vbs = vbs
|
||||
|
||||
def forward(self, x):
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
if x.size(0) <= self.vbs: # can not be chunked
|
||||
return self.bn(x)
|
||||
else:
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
@@ -548,7 +555,7 @@ class AttentionTransformer(nn.Module):
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super().__init__()
|
||||
first = True
|
||||
self.shared = nn.ModuleList()
|
||||
@@ -564,7 +571,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
def forward(self, x):
|
||||
if self.shared:
|
||||
@@ -583,10 +590,10 @@ class DecisionStep(nn.Module):
|
||||
One step for the TabNet
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):
|
||||
super().__init__()
|
||||
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
|
||||
|
||||
def forward(self, x, a, priors):
|
||||
mask = self.atten_tran(a, priors)
|
||||
|
||||
37
qlib/contrib/model/pytorch_utils.py
Normal file
37
qlib/contrib/model/pytorch_utils.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(models_or_parameters, unit="m"):
|
||||
"""
|
||||
This function is to obtain the storage size unit of a (or multiple) models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models_or_parameters : PyTorch model(s) or a list of parameters.
|
||||
unit : the storage size unit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The number of parameters of the given model(s) or parameters.
|
||||
"""
|
||||
if isinstance(models_or_parameters, nn.Module):
|
||||
counts = sum(v.numel() for v in models_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
counts = models_or_parameters.numel()
|
||||
elif isinstance(models_or_parameters, (list, tuple)):
|
||||
return sum(count_parameters(x, unit) for x in models_or_parameters)
|
||||
else:
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
unit = unit.lower()
|
||||
if unit == "kb" or unit == "k":
|
||||
counts /= 2 ** 10
|
||||
elif unit == "mb" or unit == "m":
|
||||
counts /= 2 ** 20
|
||||
elif unit == "gb" or unit == "g":
|
||||
counts /= 2 ** 30
|
||||
elif unit is not None:
|
||||
raise ValueError("Unknow unit: {:}".format(unit))
|
||||
return counts
|
||||
@@ -63,7 +63,7 @@ class UserManager:
|
||||
account_path = self.data_path / user_id
|
||||
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
|
||||
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
|
||||
cur_user_list = [user_id for user_id in self.users]
|
||||
cur_user_list = list(self.users)
|
||||
if user_id in cur_user_list:
|
||||
raise ValueError("User {} has been loaded".format(user_id))
|
||||
else:
|
||||
|
||||
@@ -148,7 +148,7 @@ class Operator:
|
||||
for user_id, user in um.users.items():
|
||||
dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange)
|
||||
if not str(dates[0].date()) == str(pred_date.date()):
|
||||
if str(dates[0].date()) != str(pred_date.date()):
|
||||
raise ValueError(
|
||||
"The account data is not newest! last trading date {}, today {}".format(
|
||||
dates[0].date(), trade_date.date()
|
||||
|
||||
@@ -161,7 +161,7 @@ class DistplotGraph(BaseGraph):
|
||||
"""
|
||||
_t_df = self._df.dropna()
|
||||
_data_list = [_t_df[_col] for _col in self._name_dict]
|
||||
_label_list = [_name for _name in self._name_dict.values()]
|
||||
_label_list = list(self._name_dict.values())
|
||||
_fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)
|
||||
|
||||
return _fig["data"]
|
||||
|
||||
@@ -7,7 +7,6 @@ import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..backtest.order import Order
|
||||
from ...utils import get_pre_trading_date
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
@@ -390,11 +389,11 @@ class TopkDropoutStrategy(BaseStrategy, ListAdjustTimer):
|
||||
current_stock_list = current_temp.get_stock_list()
|
||||
value = cash * self.risk_degree / len(buy) if len(buy) > 0 else 0
|
||||
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not consider it
|
||||
# as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
|
||||
# open_cost should be considered in the real trading environment, while the backtest in evaluate.py does not
|
||||
# consider it as the aim of demo is to accomplish same strategy as evaluate.py, so comment out this line
|
||||
# value = value / (1+trade_exchange.open_cost) # set open_cost limit
|
||||
for code in buy:
|
||||
# check is stock supended
|
||||
# check is stock suspended
|
||||
if not trade_exchange.is_stock_tradable(stock_id=code, trade_date=trade_date):
|
||||
continue
|
||||
# buy order
|
||||
|
||||
0
qlib/contrib/workflow/__init__.py
Normal file
0
qlib/contrib/workflow/__init__.py
Normal file
45
qlib/contrib/workflow/record_temp.py
Normal file
45
qlib/contrib/workflow/record_temp.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import pandas as pd
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from pprint import pprint
|
||||
import numpy as np
|
||||
|
||||
from ...workflow.record_temp import SignalRecord
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
|
||||
class SignalMseRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal MSE Record class that computes the mean squared error (MSE).
|
||||
This class inherits the ``SignalMseRecord`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
masks = ~np.isnan(label.values)
|
||||
mse = mean_squared_error(pred.values[masks], label[masks])
|
||||
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
|
||||
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
|
||||
return paths
|
||||
@@ -1045,9 +1045,6 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def __init__(self, provider):
|
||||
super(DatasetURICache, self).__init__(provider)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
|
||||
|
||||
@@ -654,9 +654,6 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
Provide expression data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
@@ -76,17 +76,6 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler : Union[dict, DataHandler]
|
||||
handler will be passed into setup_data.
|
||||
segments : dict
|
||||
handler will be passed into setup_data.
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
@@ -124,7 +113,7 @@ class DatasetH(Dataset):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -156,6 +145,11 @@ class DatasetH(Dataset):
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
@@ -168,7 +162,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -178,7 +172,7 @@ class DatasetH(Dataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : Union[List[str], Tuple[str], str, slice]
|
||||
segments : Union[List[Text], Tuple[Text], Text, slice]
|
||||
Describe the scope of the data to be prepared
|
||||
Here are some examples:
|
||||
|
||||
@@ -265,7 +259,7 @@ class TSDataSampler:
|
||||
self.fillna_type = fillna_type
|
||||
assert get_level_index(data, "datetime") == 0
|
||||
self.data = lazy_sort_index(data)
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values! But
|
||||
self.data_arr = np.array(self.data) # Get index from numpy.array will much faster than DataFrame.values!
|
||||
# NOTE: append last line with full NaN for better performance in `__getitem__`
|
||||
self.data_arr = np.append(self.data_arr, np.full((1, self.data_arr.shape[1]), np.nan), axis=0)
|
||||
self.nan_idx = -1 # The last line is all NaN
|
||||
@@ -273,7 +267,6 @@ class TSDataSampler:
|
||||
# the data type will be changed
|
||||
# The index of usable data is between start_idx and end_idx
|
||||
self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
# self.index_link = self.build_link(self.data)
|
||||
self.idx_df, self.idx_map = self.build_index(self.data)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
@@ -408,7 +401,7 @@ class TSDataSampler:
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
|
||||
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -47,8 +47,8 @@ class DataHandler(Serializable):
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
|
||||
Tips for improving the performance of datahandler
|
||||
|
||||
@@ -74,7 +74,6 @@ class NpElemOperator(ElemOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature, func):
|
||||
self.feature = feature
|
||||
self.func = func
|
||||
super(NpElemOperator, self).__init__(feature)
|
||||
|
||||
@@ -289,8 +288,6 @@ class NpPairOperator(PairOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, func):
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.func = func
|
||||
super(NpPairOperator, self).__init__(feature_left, feature_right)
|
||||
|
||||
@@ -1182,7 +1179,7 @@ class Slope(Rolling):
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with regression slope of given window
|
||||
a feature instance with linear regression slope of given window
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
@@ -1210,7 +1207,7 @@ class Rsquare(Rolling):
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with regression r-value square of given window
|
||||
a feature instance with linear regression r-value square of given window
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
@@ -1489,7 +1486,7 @@ OpsList = [
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
class OpsWrapper:
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
10
qlib/log.py
10
qlib/log.py
@@ -3,8 +3,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from typing import Optional, Text, Dict, Any
|
||||
import re
|
||||
from logging import config as logging_config
|
||||
from time import time
|
||||
@@ -13,16 +12,13 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level=None):
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
:param module_name: str
|
||||
Logic module name.
|
||||
:param level: int
|
||||
:param sh_level: int
|
||||
Stream handler log level.
|
||||
:param log_format: str
|
||||
:return: Logger
|
||||
Logger object.
|
||||
"""
|
||||
@@ -103,7 +99,7 @@ class TimeInspector:
|
||||
cls.log_cost_time(info=f"{name} Done")
|
||||
|
||||
|
||||
def set_log_with_config(log_config: dict):
|
||||
def set_log_with_config(log_config: Dict[Text, Any]):
|
||||
"""set log with config
|
||||
|
||||
:param log_config:
|
||||
|
||||
@@ -43,7 +43,8 @@ class Model(BaseModel):
|
||||
|
||||
# get weights
|
||||
try:
|
||||
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L)
|
||||
wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"],
|
||||
data_key=DataHandlerLP.DK_L)
|
||||
w_train, w_valid = wdf_train["weight"], wdf_valid["weight"]
|
||||
except KeyError as e:
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
|
||||
7
qlib/model/riskmodel/__init__.py
Normal file
7
qlib/model/riskmodel/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .base import RiskModel
|
||||
from .poet import POETCovEstimator
|
||||
from .shrink import ShrinkCovEstimator
|
||||
from .structured import StructuredCovEstimator
|
||||
147
qlib/model/riskmodel/base.py
Normal file
147
qlib/model/riskmodel/base.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import inspect
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from qlib.model.base import BaseModel
|
||||
|
||||
|
||||
class RiskModel(BaseModel):
|
||||
"""Risk Model
|
||||
|
||||
A risk model is used to estimate the covariance matrix of stock returns.
|
||||
"""
|
||||
|
||||
MASK_NAN = "mask"
|
||||
FILL_NAN = "fill"
|
||||
IGNORE_NAN = "ignore"
|
||||
|
||||
def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
|
||||
"""
|
||||
Args:
|
||||
nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
|
||||
assume_centered (bool): whether the data is assumed to be centered.
|
||||
scale_return (bool): whether scale returns as percentage.
|
||||
"""
|
||||
# nan
|
||||
assert nan_option in [
|
||||
self.MASK_NAN,
|
||||
self.FILL_NAN,
|
||||
self.IGNORE_NAN,
|
||||
], f"`nan_option={nan_option}` is not supported"
|
||||
self.nan_option = nan_option
|
||||
|
||||
self.assume_centered = assume_centered
|
||||
self.scale_return = scale_return
|
||||
|
||||
def predict(
|
||||
self,
|
||||
X: Union[pd.Series, pd.DataFrame, np.ndarray],
|
||||
return_corr: bool = False,
|
||||
is_price: bool = True,
|
||||
return_decomposed_components=False,
|
||||
) -> Union[pd.DataFrame, np.ndarray, tuple]:
|
||||
"""
|
||||
Args:
|
||||
X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
|
||||
with variables as columns and observations as rows.
|
||||
return_corr (bool): whether return the correlation matrix.
|
||||
is_price (bool): whether `X` contains price (if not assume stock returns).
|
||||
return_decomposed_components (bool): whether return decomposed components of the covariance matrix.
|
||||
|
||||
Returns:
|
||||
pd.DataFrame or np.ndarray: estimated covariance (or correlation).
|
||||
"""
|
||||
assert (
|
||||
not return_corr or not return_decomposed_components
|
||||
), "Can only return either correlation matrix or decomposed components."
|
||||
|
||||
# transform input into 2D array
|
||||
if not isinstance(X, (pd.Series, pd.DataFrame)):
|
||||
columns = None
|
||||
else:
|
||||
if isinstance(X.index, pd.MultiIndex):
|
||||
if isinstance(X, pd.DataFrame):
|
||||
X = X.iloc[:, 0].unstack(level="instrument") # always use the first column
|
||||
else:
|
||||
X = X.unstack(level="instrument")
|
||||
else:
|
||||
# X is 2D DataFrame
|
||||
pass
|
||||
columns = X.columns # will be used to restore dataframe
|
||||
X = X.values
|
||||
|
||||
# calculate pct_change
|
||||
if is_price:
|
||||
X = X[1:] / X[:-1] - 1 # NOTE: resulting `n - 1` rows
|
||||
|
||||
# scale return
|
||||
if self.scale_return:
|
||||
X *= 100
|
||||
|
||||
# handle nan and centered
|
||||
X = self._preprocess(X)
|
||||
|
||||
# return decomposed components if needed
|
||||
if return_decomposed_components:
|
||||
assert (
|
||||
"return_decomposed_components" in inspect.getfullargspec(self._predict).args
|
||||
), "This risk model does not support return decomposed components of the covariance matrix "
|
||||
|
||||
F, cov_b, var_u = self._predict(X, return_decomposed_components=True)
|
||||
return F, cov_b, var_u
|
||||
|
||||
# estimate covariance
|
||||
S = self._predict(X)
|
||||
|
||||
# return correlation if needed
|
||||
if return_corr:
|
||||
vola = np.sqrt(np.diag(S))
|
||||
corr = S / np.outer(vola, vola)
|
||||
if columns is None:
|
||||
return corr
|
||||
return pd.DataFrame(corr, index=columns, columns=columns)
|
||||
|
||||
# return covariance
|
||||
if columns is None:
|
||||
return S
|
||||
return pd.DataFrame(S, index=columns, columns=columns)
|
||||
|
||||
def _predict(self, X: np.ndarray) -> np.ndarray:
|
||||
"""covariance estimation implementation
|
||||
|
||||
This method should be overridden by child classes.
|
||||
|
||||
By default, this method implements the empirical covariance estimation.
|
||||
|
||||
Args:
|
||||
X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
|
||||
|
||||
Returns:
|
||||
np.ndarray: covariance matrix.
|
||||
"""
|
||||
xTx = np.asarray(X.T.dot(X))
|
||||
N = len(X)
|
||||
if isinstance(X, np.ma.MaskedArray):
|
||||
M = 1 - X.mask
|
||||
N = M.T.dot(M) # each pair has distinct number of samples
|
||||
return xTx / N
|
||||
|
||||
def _preprocess(self, X: np.ndarray) -> Union[np.ndarray, np.ma.MaskedArray]:
|
||||
"""handle nan and centerize data
|
||||
|
||||
Note:
|
||||
if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
|
||||
"""
|
||||
# handle nan
|
||||
if self.nan_option == self.FILL_NAN:
|
||||
X = np.nan_to_num(X)
|
||||
elif self.nan_option == self.MASK_NAN:
|
||||
X = np.ma.masked_invalid(X)
|
||||
# centralize
|
||||
if not self.assume_centered:
|
||||
X = X - np.nanmean(X, axis=0)
|
||||
return X
|
||||
84
qlib/model/riskmodel/poet.py
Normal file
84
qlib/model/riskmodel/poet.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import numpy as np
|
||||
|
||||
from qlib.model.riskmodel import RiskModel
|
||||
|
||||
|
||||
class POETCovEstimator(RiskModel):
|
||||
"""Principal Orthogonal Complement Thresholding Estimator (POET)
|
||||
|
||||
Reference:
|
||||
[1] Fan, J., Liao, Y., & Mincheva, M. (2013). Large covariance estimation by thresholding principal orthogonal complements.
|
||||
Journal of the Royal Statistical Society. Series B: Statistical Methodology, 75(4), 603–680. https://doi.org/10.1111/rssb.12016
|
||||
[2] http://econweb.rutgers.edu/yl1114/papers/poet/POET.m
|
||||
"""
|
||||
|
||||
THRESH_SOFT = "soft"
|
||||
THRESH_HARD = "hard"
|
||||
THRESH_SCAD = "scad"
|
||||
|
||||
def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
|
||||
"""
|
||||
Args:
|
||||
num_factors (int): number of factors (if set to zero, no factor model will be used).
|
||||
thresh (float): the positive constant for thresholding.
|
||||
thresh_method (str): thresholding method, which can be
|
||||
- 'soft': soft thresholding.
|
||||
- 'hard': hard thresholding.
|
||||
- 'scad': scad thresholding.
|
||||
kwargs: see `RiskModel` for more information.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
assert num_factors >= 0, "`num_factors` requires a positive integer"
|
||||
self.num_factors = num_factors
|
||||
|
||||
assert thresh >= 0, "`thresh` requires a positive float number"
|
||||
self.thresh = thresh
|
||||
|
||||
assert thresh_method in [
|
||||
self.THRESH_HARD,
|
||||
self.THRESH_SOFT,
|
||||
self.THRESH_SCAD,
|
||||
], "`thresh_method` should be `soft`/`hard`/`scad`"
|
||||
self.thresh_method = thresh_method
|
||||
|
||||
def _predict(self, X: np.ndarray) -> np.ndarray:
|
||||
|
||||
Y = X.T # NOTE: to match POET's implementation
|
||||
p, n = Y.shape
|
||||
|
||||
if self.num_factors > 0:
|
||||
Dd, V = np.linalg.eig(Y.T.dot(Y))
|
||||
V = V[:, np.argsort(Dd)]
|
||||
F = V[:, -self.num_factors :][:, ::-1] * np.sqrt(n)
|
||||
LamPCA = Y.dot(F) / n
|
||||
uhat = np.asarray(Y - LamPCA.dot(F.T))
|
||||
Lowrank = np.asarray(LamPCA.dot(LamPCA.T))
|
||||
rate = 1 / np.sqrt(p) + np.sqrt(np.log(p) / n)
|
||||
else:
|
||||
uhat = np.asarray(Y)
|
||||
rate = np.sqrt(np.log(p) / n)
|
||||
Lowrank = 0
|
||||
|
||||
lamb = rate * self.thresh
|
||||
SuPCA = uhat.dot(uhat.T) / n
|
||||
SuDiag = np.diag(np.diag(SuPCA))
|
||||
R = np.linalg.inv(SuDiag ** 0.5).dot(SuPCA).dot(np.linalg.inv(SuDiag ** 0.5))
|
||||
|
||||
if self.thresh_method == self.THRESH_HARD:
|
||||
M = R * (np.abs(R) > lamb)
|
||||
elif self.thresh_method == self.THRESH_SOFT:
|
||||
res = np.abs(R) - lamb
|
||||
res = (res + np.abs(res)) / 2
|
||||
M = np.sign(R) * res
|
||||
else:
|
||||
M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb)
|
||||
M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7
|
||||
M3 = (np.abs(R) >= 3.7 * lamb) * R
|
||||
M = M1 + M2 + M3
|
||||
|
||||
Rthresh = M - np.diag(np.diag(M)) + np.eye(p)
|
||||
SigmaU = (SuDiag ** 0.5).dot(Rthresh).dot(SuDiag ** 0.5)
|
||||
SigmaY = SigmaU + Lowrank
|
||||
|
||||
return SigmaY
|
||||
@@ -1,133 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from typing import Union
|
||||
|
||||
from qlib.model.base import BaseModel
|
||||
|
||||
|
||||
class RiskModel(BaseModel):
|
||||
"""Risk Model
|
||||
|
||||
A risk model is used to estimate the covariance matrix of stock returns.
|
||||
"""
|
||||
|
||||
MASK_NAN = "mask"
|
||||
FILL_NAN = "fill"
|
||||
IGNORE_NAN = "ignore"
|
||||
|
||||
def __init__(self, nan_option: str = "ignore", assume_centered: bool = False, scale_return: bool = True):
|
||||
"""
|
||||
Args:
|
||||
nan_option (str): nan handling option (`ignore`/`mask`/`fill`).
|
||||
assume_centered (bool): whether the data is assumed to be centered.
|
||||
scale_return (bool): whether scale returns as percentage.
|
||||
"""
|
||||
# nan
|
||||
assert nan_option in [
|
||||
self.MASK_NAN,
|
||||
self.FILL_NAN,
|
||||
self.IGNORE_NAN,
|
||||
], f"`nan_option={nan_option}` is not supported"
|
||||
self.nan_option = nan_option
|
||||
|
||||
self.assume_centered = assume_centered
|
||||
self.scale_return = scale_return
|
||||
|
||||
def predict(
|
||||
self, X: Union[pd.Series, pd.DataFrame, np.ndarray], return_corr: bool = False, is_price: bool = True
|
||||
) -> Union[pd.DataFrame, np.ndarray]:
|
||||
"""
|
||||
Args:
|
||||
X (pd.Series, pd.DataFrame or np.ndarray): data from which to estimate the covariance,
|
||||
with variables as columns and observations as rows.
|
||||
return_corr (bool): whether return the correlation matrix.
|
||||
is_price (bool): whether `X` contains price (if not assume stock returns).
|
||||
|
||||
Returns:
|
||||
pd.DataFrame or np.ndarray: estimated covariance (or correlation).
|
||||
"""
|
||||
# transform input into 2D array
|
||||
if not isinstance(X, (pd.Series, pd.DataFrame)):
|
||||
columns = None
|
||||
else:
|
||||
if isinstance(X.index, pd.MultiIndex):
|
||||
if isinstance(X, pd.DataFrame):
|
||||
X = X.iloc[:, 0].unstack(level="instrument") # always use the first column
|
||||
else:
|
||||
X = X.unstack(level="instrument")
|
||||
else:
|
||||
# X is 2D DataFrame
|
||||
pass
|
||||
columns = X.columns # will be used to restore dataframe
|
||||
X = X.values
|
||||
|
||||
# calculate pct_change
|
||||
if is_price:
|
||||
X = X[1:] / X[:-1] - 1 # NOTE: resulting `n - 1` rows
|
||||
|
||||
# scale return
|
||||
if self.scale_return:
|
||||
X *= 100
|
||||
|
||||
# handle nan and centered
|
||||
X = self._preprocess(X)
|
||||
|
||||
# estimate covariance
|
||||
S = self._predict(X)
|
||||
|
||||
# return correlation if needed
|
||||
if return_corr:
|
||||
vola = np.sqrt(np.diag(S))
|
||||
corr = S / np.outer(vola, vola)
|
||||
if columns is None:
|
||||
return corr
|
||||
return pd.DataFrame(corr, index=columns, columns=columns)
|
||||
|
||||
# return covariance
|
||||
if columns is None:
|
||||
return S
|
||||
return pd.DataFrame(S, index=columns, columns=columns)
|
||||
|
||||
def _predict(self, X: np.ndarray) -> np.ndarray:
|
||||
"""covariance estimation implementation
|
||||
|
||||
This method should be overridden by child classes.
|
||||
|
||||
By default, this method implements the empirical covariance estimation.
|
||||
|
||||
Args:
|
||||
X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
|
||||
|
||||
Returns:
|
||||
np.ndarray: covariance matrix.
|
||||
"""
|
||||
xTx = np.asarray(X.T.dot(X))
|
||||
N = len(X)
|
||||
if isinstance(X, np.ma.MaskedArray):
|
||||
M = 1 - X.mask
|
||||
N = M.T.dot(M) # each pair has distinct number of samples
|
||||
return xTx / N
|
||||
|
||||
def _preprocess(self, X: np.ndarray) -> Union[np.ndarray, np.ma.MaskedArray]:
|
||||
"""handle nan and centerize data
|
||||
|
||||
Note:
|
||||
if `nan_option='mask'` then the returned array will be `np.ma.MaskedArray`.
|
||||
"""
|
||||
# handle nan
|
||||
if self.nan_option == self.FILL_NAN:
|
||||
X = np.nan_to_num(X)
|
||||
elif self.nan_option == self.MASK_NAN:
|
||||
X = np.ma.masked_invalid(X)
|
||||
# centerize
|
||||
if not self.assume_centered:
|
||||
X = X - np.nanmean(X, axis=0)
|
||||
return X
|
||||
from qlib.model.riskmodel import RiskModel
|
||||
|
||||
|
||||
class ShrinkCovEstimator(RiskModel):
|
||||
@@ -162,8 +36,9 @@ class ShrinkCovEstimator(RiskModel):
|
||||
[3] Ledoit, O., & Wolf, M. (2003). Improved estimation of the covariance matrix of stock returns
|
||||
with an application to portfolio selection.
|
||||
Journal of Empirical Finance, 10(5), 603–621. https://doi.org/10.1016/S0927-5398(03)00007-0
|
||||
[4] Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. (2010). Shrinkage algorithms for MMSE covariance estimation.
|
||||
IEEE Transactions on Signal Processing, 58(10), 5016–5029. https://doi.org/10.1109/TSP.2010.2053029
|
||||
[4] Chen, Y., Wiesel, A., Eldar, Y. C., & Hero, A. O. (2010). Shrinkage algorithms for MMSE covariance
|
||||
estimation. IEEE Transactions on Signal Processing, 58(10), 5016–5029.
|
||||
https://doi.org/10.1109/TSP.2010.2053029
|
||||
[5] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-00007f64e5b9/cov1para.m.zip
|
||||
[6] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-ffff-ffffde5e2d4e/covCor.m.zip
|
||||
[7] https://www.econ.uzh.ch/dam/jcr:ffffffff-935a-b0d6-0000-0000648dfc98/covMarket.m.zip
|
||||
@@ -384,84 +259,3 @@ class ShrinkCovEstimator(RiskModel):
|
||||
alpha = max(0, min(1, kappa / t))
|
||||
|
||||
return alpha
|
||||
|
||||
|
||||
class POETCovEstimator(RiskModel):
|
||||
"""Principal Orthogonal Complement Thresholding Estimator (POET)
|
||||
|
||||
Reference:
|
||||
[1] Fan, J., Liao, Y., & Mincheva, M. (2013). Large covariance estimation by thresholding principal orthogonal complements.
|
||||
Journal of the Royal Statistical Society. Series B: Statistical Methodology, 75(4), 603–680. https://doi.org/10.1111/rssb.12016
|
||||
[2] http://econweb.rutgers.edu/yl1114/papers/poet/POET.m
|
||||
"""
|
||||
|
||||
THRESH_SOFT = "soft"
|
||||
THRESH_HARD = "hard"
|
||||
THRESH_SCAD = "scad"
|
||||
|
||||
def __init__(self, num_factors: int = 0, thresh: float = 1.0, thresh_method: str = "soft", **kwargs):
|
||||
"""
|
||||
Args:
|
||||
num_factors (int): number of factors (if set to zero, no factor model will be used).
|
||||
thresh (float): the positive constant for thresholding.
|
||||
thresh_method (str): thresholding method, which can be
|
||||
- 'soft': soft thresholding.
|
||||
- 'hard': hard thresholding.
|
||||
- 'scad': scad thresholding.
|
||||
kwargs: see `RiskModel` for more information.
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
assert num_factors >= 0, "`num_factors` requires a positive integer"
|
||||
self.num_factors = num_factors
|
||||
|
||||
assert thresh >= 0, "`thresh` requires a positive float number"
|
||||
self.thresh = thresh
|
||||
|
||||
assert thresh_method in [
|
||||
self.THRESH_HARD,
|
||||
self.THRESH_SOFT,
|
||||
self.THRESH_SCAD,
|
||||
], "`thresh_method` should be `soft`/`hard`/`scad`"
|
||||
self.thresh_method = thresh_method
|
||||
|
||||
def _predict(self, X: np.ndarray) -> np.ndarray:
|
||||
|
||||
Y = X.T # NOTE: to match POET's implementation
|
||||
p, n = Y.shape
|
||||
|
||||
if self.num_factors > 0:
|
||||
Dd, V = np.linalg.eig(Y.T.dot(Y))
|
||||
V = V[:, np.argsort(Dd)]
|
||||
F = V[:, -self.num_factors :][:, ::-1] * np.sqrt(n)
|
||||
LamPCA = Y.dot(F) / n
|
||||
uhat = np.asarray(Y - LamPCA.dot(F.T))
|
||||
Lowrank = np.asarray(LamPCA.dot(LamPCA.T))
|
||||
rate = 1 / np.sqrt(p) + np.sqrt(np.log(p) / n)
|
||||
else:
|
||||
uhat = np.asarray(Y)
|
||||
rate = np.sqrt(np.log(p) / n)
|
||||
Lowrank = 0
|
||||
|
||||
lamb = rate * self.thresh
|
||||
SuPCA = uhat.dot(uhat.T) / n
|
||||
SuDiag = np.diag(np.diag(SuPCA))
|
||||
R = np.linalg.inv(SuDiag ** 0.5).dot(SuPCA).dot(np.linalg.inv(SuDiag ** 0.5))
|
||||
|
||||
if self.thresh_method == self.THRESH_HARD:
|
||||
M = R * (np.abs(R) > lamb)
|
||||
elif self.thresh_method == self.THRESH_SOFT:
|
||||
res = np.abs(R) - lamb
|
||||
res = (res + np.abs(res)) / 2
|
||||
M = np.sign(R) * res
|
||||
else:
|
||||
M1 = (np.abs(R) < 2 * lamb) * np.sign(R) * (np.abs(R) - lamb) * (np.abs(R) > lamb)
|
||||
M2 = (np.abs(R) < 3.7 * lamb) * (np.abs(R) >= 2 * lamb) * (2.7 * R - 3.7 * np.sign(R) * lamb) / 1.7
|
||||
M3 = (np.abs(R) >= 3.7 * lamb) * R
|
||||
M = M1 + M2 + M3
|
||||
|
||||
Rthresh = M - np.diag(np.diag(M)) + np.eye(p)
|
||||
SigmaU = (SuDiag ** 0.5).dot(Rthresh).dot(SuDiag ** 0.5)
|
||||
SigmaY = SigmaU + Lowrank
|
||||
|
||||
return SigmaY
|
||||
84
qlib/model/riskmodel/structured.py
Normal file
84
qlib/model/riskmodel/structured.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
from sklearn.decomposition import PCA, FactorAnalysis
|
||||
|
||||
from qlib.model.riskmodel import RiskModel
|
||||
|
||||
|
||||
class StructuredCovEstimator(RiskModel):
|
||||
"""Structured Covariance Estimator
|
||||
|
||||
This estimator assumes observations can be predicted by multiple factors
|
||||
X = FB + U
|
||||
where `F` can be specified by explicit risk factors or latent factors.
|
||||
|
||||
Therefore the structured covariance can be estimated by
|
||||
cov(X) = F cov(B) F.T + cov(U)
|
||||
|
||||
We use latent factor models to estimate the structured covariance.
|
||||
Specifically, the following latent factor models are supported:
|
||||
- `pca`: Principal Component Analysis
|
||||
- `fa`: Factor Analysis
|
||||
|
||||
Reference: [1] Fan, J., Liao, Y., & Liu, H. (2016). An overview of the estimation of large covariance and
|
||||
precision matrices. Econometrics Journal, 19(1), C1–C32. https://doi.org/10.1111/ectj.12061
|
||||
"""
|
||||
|
||||
FACTOR_MODEL_PCA = "pca"
|
||||
FACTOR_MODEL_FA = "fa"
|
||||
DEFAULT_NAN_OPTION = "fill"
|
||||
|
||||
def __init__(self, factor_model: str = "pca", num_factors: int = 10, **kwargs):
|
||||
"""
|
||||
Args:
|
||||
factor_model (str): the latent factor models used to estimate the structured covariance (`pca`/`fa`).
|
||||
num_factors (int): number of components to keep.
|
||||
kwargs: see `RiskModel` for more information
|
||||
"""
|
||||
if "nan_option" in kwargs.keys():
|
||||
assert kwargs["nan_option"] in [self.DEFAULT_NAN_OPTION], "nan_option={} is not supported".format(
|
||||
kwargs["nan_option"]
|
||||
)
|
||||
else:
|
||||
kwargs["nan_option"] = self.DEFAULT_NAN_OPTION
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
assert factor_model in [
|
||||
self.FACTOR_MODEL_PCA,
|
||||
self.FACTOR_MODEL_FA,
|
||||
], "factor_model={} is not supported".format(factor_model)
|
||||
self.solver = PCA if factor_model == self.FACTOR_MODEL_PCA else FactorAnalysis
|
||||
|
||||
self.num_factors = num_factors
|
||||
|
||||
def _predict(self, X: np.ndarray, return_decomposed_components=False) -> Union[np.ndarray, tuple]:
|
||||
"""
|
||||
covariance estimation implementation
|
||||
|
||||
Args:
|
||||
X (np.ndarray): data matrix containing multiple variables (columns) and observations (rows).
|
||||
return_decomposed_components (bool): whether return decomposed components of the covariance matrix.
|
||||
|
||||
Returns:
|
||||
tuple or np.ndarray: decomposed covariance matrix or covariance matrix.
|
||||
"""
|
||||
|
||||
model = self.solver(self.num_factors, random_state=0).fit(X)
|
||||
|
||||
F = model.components_.T # num_features x num_factors
|
||||
B = model.transform(X) # num_samples x num_factors
|
||||
U = X - B @ F.T
|
||||
cov_b = np.cov(B.T) # num_factors x num_factors
|
||||
var_u = np.var(U, axis=0) # diagonal
|
||||
|
||||
if return_decomposed_components:
|
||||
return F, cov_b, var_u
|
||||
|
||||
cov_x = F @ cov_b @ F.T + np.diag(var_u)
|
||||
|
||||
return cov_x
|
||||
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
6
qlib/portfolio/optimizer/__init__.py
Normal file
6
qlib/portfolio/optimizer/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .base import BaseOptimizer
|
||||
from .optimizer import PortfolioOptimizer
|
||||
from .enhanced_indexing import EnhancedIndexingOptimizer
|
||||
13
qlib/portfolio/optimizer/base.py
Normal file
13
qlib/portfolio/optimizer/base.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
|
||||
|
||||
class BaseOptimizer(abc.ABC):
|
||||
""" Construct portfolio with a optimization related method """
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" Generate a optimized portfolio allocation """
|
||||
pass
|
||||
143
qlib/portfolio/optimizer/enhanced_indexing.py
Normal file
143
qlib/portfolio/optimizer/enhanced_indexing.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import cvxpy as cp
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from qlib.portfolio.optimizer import BaseOptimizer
|
||||
|
||||
|
||||
class EnhancedIndexingOptimizer(BaseOptimizer):
|
||||
"""
|
||||
Portfolio Optimizer with Enhanced Indexing
|
||||
|
||||
Note:
|
||||
This optimizer always assumes full investment and no-shorting.
|
||||
"""
|
||||
|
||||
START_FROM_W0 = "w0"
|
||||
START_FROM_BENCH = "benchmark"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
lamb: float = 10,
|
||||
delta: float = 0.4,
|
||||
bench_dev: float = 0.01,
|
||||
inds_dev: float = None,
|
||||
scale_alpha: bool = True,
|
||||
verbose: bool = False,
|
||||
warm_start: str = None,
|
||||
max_iters: int = 10000,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
lamb (float): risk aversion parameter (larger `lamb` means less focus on return)
|
||||
delta (float): turnover rate limit
|
||||
bench_dev (float): benchmark deviation limit
|
||||
inds_dev (float/None): industry deviation limit, set `inds_dev` to None to ignore industry specific
|
||||
restriction
|
||||
scale_alpha (bool): if to scale alpha to match the volatility of the covariance matrix
|
||||
verbose (bool): if print detailed information about the solver
|
||||
warm_start (str): whether try to warm start (`w0`/`benchmark`/``)
|
||||
(https://www.cvxpy.org/tutorial/advanced/index.html#warm-start)
|
||||
"""
|
||||
|
||||
assert lamb >= 0, "risk aversion parameter `lamb` should be positive"
|
||||
self.lamb = lamb
|
||||
|
||||
assert delta >= 0, "turnover limit `delta` should be positive"
|
||||
self.delta = delta
|
||||
|
||||
assert bench_dev >= 0, "benchmark deviation limit `bench_dev` should be positive"
|
||||
self.bench_dev = bench_dev
|
||||
|
||||
assert inds_dev is None or inds_dev >= 0, "industry deviation limit `inds_dev` should be positive or None."
|
||||
self.inds_dev = inds_dev
|
||||
|
||||
assert warm_start in [
|
||||
None,
|
||||
self.START_FROM_W0,
|
||||
self.START_FROM_BENCH,
|
||||
], "illegal warm start option"
|
||||
self.start_from_w0 = warm_start == self.START_FROM_W0
|
||||
self.start_from_bench = warm_start == self.START_FROM_BENCH
|
||||
|
||||
self.scale_alpha = scale_alpha
|
||||
self.verbose = verbose
|
||||
self.max_iters = max_iters
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
u: Union[np.ndarray, pd.Series],
|
||||
F: np.ndarray,
|
||||
covB: np.ndarray,
|
||||
varU: np.ndarray,
|
||||
w0: np.ndarray,
|
||||
w_bench: np.ndarray,
|
||||
inds_onehot: np.ndarray = None,
|
||||
) -> Union[np.ndarray, pd.Series]:
|
||||
"""
|
||||
Args:
|
||||
u (np.ndarray or pd.Series): expected returns (a.k.a., alpha)
|
||||
F, covB, varU (np.ndarray): see StructuredCovEstimator
|
||||
w0 (np.ndarray): initial weights (for turnover control)
|
||||
w_bench (np.ndarray): benchmark weights
|
||||
inds_onehot (np.ndarray): industry (onehot)
|
||||
|
||||
Returns:
|
||||
np.ndarray or pd.Series: optimized portfolio allocation
|
||||
"""
|
||||
assert inds_onehot is not None or self.inds_dev is None, "Industry onehot vector is required."
|
||||
|
||||
# transform dataframe into array
|
||||
if isinstance(u, pd.Series):
|
||||
u = u.values
|
||||
|
||||
# scale alpha to match volatility
|
||||
if self.scale_alpha:
|
||||
u = u / u.std()
|
||||
x_variance = np.mean(np.diag(F @ covB @ F.T) + varU)
|
||||
u *= x_variance ** 0.5
|
||||
|
||||
w = cp.Variable(len(u)) # num_assets
|
||||
v = w @ F # num_factors
|
||||
ret = w @ u
|
||||
risk = cp.quad_form(v, covB) + cp.sum(cp.multiply(varU, w ** 2))
|
||||
obj = cp.Maximize(ret - self.lamb * risk)
|
||||
d_bench = w - w_bench
|
||||
cons = [
|
||||
w >= 0,
|
||||
cp.sum(w) == 1,
|
||||
d_bench >= -self.bench_dev,
|
||||
d_bench <= self.bench_dev,
|
||||
]
|
||||
|
||||
if self.inds_dev is not None:
|
||||
d_inds = d_bench @ inds_onehot
|
||||
cons.append(d_inds >= -self.inds_dev)
|
||||
cons.append(d_inds <= self.inds_dev)
|
||||
|
||||
if w0 is not None:
|
||||
turnover = cp.sum(cp.abs(w - w0))
|
||||
cons.append(turnover <= self.delta)
|
||||
|
||||
warm_start = False
|
||||
if self.start_from_w0:
|
||||
if w0 is None:
|
||||
print("Warning: try warm start with w0, but w0 is `None`.")
|
||||
else:
|
||||
w.value = w0
|
||||
warm_start = True
|
||||
elif self.start_from_bench:
|
||||
w.value = w_bench
|
||||
warm_start = True
|
||||
|
||||
prob = cp.Problem(obj, cons)
|
||||
prob.solve(solver=cp.SCS, verbose=self.verbose, warm_start=warm_start, max_iters=self.max_iters)
|
||||
|
||||
if prob.status != "optimal":
|
||||
print("Warning: solve failed.", prob.status)
|
||||
|
||||
return np.asarray(w.value)
|
||||
@@ -1,15 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy.optimize as so
|
||||
|
||||
from typing import Optional, Union, Callable, List
|
||||
|
||||
from qlib.portfolio.optimizer import BaseOptimizer
|
||||
|
||||
class PortfolioOptimizer:
|
||||
|
||||
class PortfolioOptimizer(BaseOptimizer):
|
||||
"""Portfolio Optimizer
|
||||
|
||||
The following optimization algorithms are supported:
|
||||
@@ -42,6 +44,7 @@ class PortfolioOptimizer:
|
||||
lamb (float): risk aversion parameter (larger `lamb` means more focus on return)
|
||||
delta (float): turnover rate limit
|
||||
alpha (float): l2 norm regularizer
|
||||
scale_alpha (bool): if to scale alpha to match the volatility of the covariance matrix
|
||||
tol (float): tolerance for optimization termination
|
||||
"""
|
||||
assert method in [self.OPT_GMV, self.OPT_MVO, self.OPT_RP, self.OPT_INV], f"method `{method}` is not supported"
|
||||
@@ -57,6 +60,7 @@ class PortfolioOptimizer:
|
||||
self.alpha = alpha
|
||||
|
||||
self.tol = tol
|
||||
self.scale_alpha = scale_alpha
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -83,18 +87,18 @@ class PortfolioOptimizer:
|
||||
if u is not None:
|
||||
assert len(u) == len(S), "`u` has mismatched shape"
|
||||
if isinstance(u, pd.Series):
|
||||
assert all(u.index == index), "`u` has mismatched index"
|
||||
assert u.index.equals(index), "`u` has mismatched index"
|
||||
u = u.values
|
||||
|
||||
# transform initial weights
|
||||
if w0 is not None:
|
||||
assert len(w0) == len(S), "`w0` has mismatched shape"
|
||||
if isinstance(w0, pd.Series):
|
||||
assert all(w0.index == index), "`w0` has mismatched index"
|
||||
assert w0.index.equals(index), "`w0` has mismatched index"
|
||||
w0 = w0.values
|
||||
|
||||
# scale alpha to match volatility
|
||||
if u is not None:
|
||||
if u is not None and self.scale_alpha:
|
||||
u = u / u.std()
|
||||
u *= np.mean(np.diag(S)) ** 0.5
|
||||
|
||||
@@ -173,7 +177,7 @@ class PortfolioOptimizer:
|
||||
"""
|
||||
return self._solve(len(S), self._get_objective_rp(S), *self._get_constrains(w0))
|
||||
|
||||
def _get_objective_gmv(self, S: np.ndarray) -> np.ndarray:
|
||||
def _get_objective_gmv(self, S: np.ndarray) -> Callable:
|
||||
"""global minimum variance optimization objective
|
||||
|
||||
Optimization objective
|
||||
@@ -185,7 +189,7 @@ class PortfolioOptimizer:
|
||||
|
||||
return func
|
||||
|
||||
def _get_objective_mvo(self, S: np.ndarray, u: np.ndarray = None) -> np.ndarray:
|
||||
def _get_objective_mvo(self, S: np.ndarray, u: np.ndarray = None) -> Callable:
|
||||
"""mean-variance optimization objective
|
||||
|
||||
Optimization objective
|
||||
@@ -199,7 +203,7 @@ class PortfolioOptimizer:
|
||||
|
||||
return func
|
||||
|
||||
def _get_objective_rp(self, S: np.ndarray) -> np.ndarray:
|
||||
def _get_objective_rp(self, S: np.ndarray) -> Callable:
|
||||
"""risk-parity optimization objective
|
||||
|
||||
Optimization objective
|
||||
@@ -247,7 +251,11 @@ class PortfolioOptimizer:
|
||||
# add l2 regularization
|
||||
wrapped_obj = obj
|
||||
if self.alpha > 0:
|
||||
wrapped_obj = lambda x: obj(x) + self.alpha * np.sum(np.square(x))
|
||||
|
||||
def opt_obj(x):
|
||||
return obj(x) + self.alpha * np.sum(np.square(x))
|
||||
|
||||
wrapped_obj = opt_obj
|
||||
|
||||
# solve
|
||||
x0 = np.ones(n) / n # init results
|
||||
@@ -24,7 +24,7 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
from typing import Union, Tuple, Text, Optional
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -64,7 +64,7 @@ def np_ffill(arr: np.array):
|
||||
arr : np.array
|
||||
Input numpy 1D array
|
||||
"""
|
||||
mask = np.isnan(arr.astype(np.float)) # np.isnan only works on np.float
|
||||
mask = np.isnan(arr.astype(float)) # np.isnan only works on np.float
|
||||
# get fill index
|
||||
idx = np.where(~mask, np.arange(mask.shape[0]), 0)
|
||||
np.maximum.accumulate(idx, out=idx)
|
||||
@@ -212,7 +212,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> object:
|
||||
"""
|
||||
get initialized instance with config
|
||||
@@ -276,23 +276,31 @@ def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
return changes
|
||||
|
||||
|
||||
def create_save_path(save_path=None):
|
||||
"""Create save path
|
||||
def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
|
||||
"""Create or get a file or directory given the path and return_dir.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_path: str
|
||||
path: a string indicates the path or None indicates creating a temporary path.
|
||||
return_dir: if True, create and return a directory; otherwise c&r a file.
|
||||
|
||||
"""
|
||||
if save_path:
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
if path:
|
||||
if return_dir and not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
elif not return_dir: # return a file, thus we need to create its parent directory
|
||||
xpath = os.path.abspath(os.path.join(path, ".."))
|
||||
if not os.path.exists(xpath):
|
||||
os.makedirs(xpath)
|
||||
else:
|
||||
temp_dir = os.path.expanduser("~/tmp")
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
_, save_path = tempfile.mkstemp(dir=temp_dir)
|
||||
return save_path
|
||||
if return_dir:
|
||||
_, path = tempfile.mkdtemp(dir=temp_dir)
|
||||
else:
|
||||
_, path = tempfile.mkstemp(dir=temp_dir)
|
||||
return path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -722,6 +730,9 @@ class Wrapper:
|
||||
def register(self, provider):
|
||||
self._provider = provider
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional
|
||||
from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
@@ -16,26 +17,47 @@ class QlibRecorder:
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name=None, recorder_name=None):
|
||||
def start(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
):
|
||||
"""
|
||||
Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
|
||||
|
||||
.. code-block:: Python
|
||||
|
||||
# start new experiment and recorder
|
||||
with R.start('test', 'recorder_1'):
|
||||
model.fit(dataset)
|
||||
R.log...
|
||||
... # further operations
|
||||
|
||||
# resume previous experiment and recorder
|
||||
with R.start('test', 'recorder_1', resume=True): # if users want to resume recorder, they have to specify the exact same name for experiment and recorder.
|
||||
... # further operations
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_name : str
|
||||
name of the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when users call this function in the same experiment,
|
||||
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
|
||||
resume : bool
|
||||
whether to resume the specific recorder with given name under the given experiment.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name)
|
||||
run = self.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -43,7 +65,7 @@ class QlibRecorder:
|
||||
raise e
|
||||
self.end_exp(Recorder.STATUS_FI)
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False):
|
||||
"""
|
||||
Lower level method for starting an experiment. When use this method, one should end the experiment manually
|
||||
and the status of the recorder may not be handled properly. Here is the example code:
|
||||
@@ -64,12 +86,14 @@ class QlibRecorder:
|
||||
uri : str
|
||||
the tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri are set in the qlib.config.
|
||||
resume : bool
|
||||
whether to resume the specific recorder with given name under the given experiment.
|
||||
|
||||
Returns
|
||||
-------
|
||||
An experiment instance being started.
|
||||
"""
|
||||
return self.exp_manager.start_exp(experiment_name, recorder_name, uri)
|
||||
return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume)
|
||||
|
||||
def end_exp(self, recorder_status=Recorder.STATUS_FI):
|
||||
"""
|
||||
@@ -272,7 +296,13 @@ class QlibRecorder:
|
||||
-------
|
||||
The uri of current experiment manager.
|
||||
"""
|
||||
return self.exp_manager.get_uri()
|
||||
return self.exp_manager.uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text]):
|
||||
"""
|
||||
Method to reset the current uri of current experiment manager.
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_path_list(path):
|
||||
if isinstance(path, str):
|
||||
return [path]
|
||||
else:
|
||||
return [p for p in path]
|
||||
return list(path)
|
||||
|
||||
|
||||
def sys_config(config, config_path):
|
||||
|
||||
@@ -23,7 +23,7 @@ class Experiment:
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -39,7 +39,7 @@ class Experiment:
|
||||
output["recorders"] = list(recorders.keys())
|
||||
return output
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
"""
|
||||
Start the experiment and set it to be active. This method will also start a new recorder.
|
||||
|
||||
@@ -47,6 +47,8 @@ class Experiment:
|
||||
----------
|
||||
recorder_name : str
|
||||
the name of the recorder to be created.
|
||||
resume : bool
|
||||
whether to resume the first recorder
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -149,7 +151,57 @@ class Experiment:
|
||||
-------
|
||||
A recorder object.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_recorder` method.")
|
||||
# special case of getting the recorder
|
||||
if recorder_id is None and recorder_name is None:
|
||||
if self.active_recorder is not None:
|
||||
return self.active_recorder
|
||||
recorder_name = self._default_rec_name
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
return recorder
|
||||
|
||||
def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, bool):
|
||||
"""
|
||||
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
|
||||
automatically create a new recorder based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
if recorder_id is None and recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
|
||||
return self.create_recorder(recorder_name), True
|
||||
|
||||
def _get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
"""
|
||||
Get specific recorder by name or id. If it does not exist, raise ValueError
|
||||
|
||||
Parameters
|
||||
----------
|
||||
recorder_id :
|
||||
The id of recorder
|
||||
recorder_name :
|
||||
The name of recorder
|
||||
|
||||
Returns
|
||||
-------
|
||||
Recorder:
|
||||
The searched recorder
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self):
|
||||
"""
|
||||
@@ -173,16 +225,25 @@ class MLflowExperiment(Experiment):
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.name)
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None, resume=False):
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# set up recorder
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
# Get or create recorder
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
# resume the recorder
|
||||
if resume:
|
||||
recorder, _ = self._get_or_create_rec(recorder_name=recorder_name)
|
||||
# create a new recorder
|
||||
else:
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
# Set up active recorder
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
# Start the recorder
|
||||
self.active_recorder.start_run()
|
||||
|
||||
return self.active_recorder
|
||||
@@ -199,36 +260,6 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
return recorder
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, create=True):
|
||||
# special case of getting the recorder
|
||||
if recorder_id is None and recorder_name is None:
|
||||
if self.active_recorder is not None:
|
||||
return self.active_recorder
|
||||
recorder_name = self._default_rec_name
|
||||
if create:
|
||||
recorder, is_new = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name)
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
mlflow.set_experiment(self.name)
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
return recorder
|
||||
|
||||
def _get_or_create_rec(self, recorder_id=None, recorder_name=None) -> (object, bool):
|
||||
"""
|
||||
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
|
||||
automatically create a new recorder based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
return self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
except ValueError:
|
||||
if recorder_name is None:
|
||||
recorder_name = self._default_rec_name
|
||||
logger.info(f"No valid recorder found. Create a new recorder with name {recorder_name}.")
|
||||
return self.create_recorder(recorder_name), True
|
||||
|
||||
def _get_recorder(self, recorder_id=None, recorder_name=None):
|
||||
"""
|
||||
Method for getting or creating a recorder. It will try to first get a valid recorder, if exception occurs, it will
|
||||
@@ -239,14 +270,14 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input at least one of recorder id or name before retrieving recorder."
|
||||
if recorder_id is not None:
|
||||
try:
|
||||
run = self.client.get_run(recorder_id)
|
||||
run = self._client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException:
|
||||
raise ValueError("No valid recorder has been found, please make sure the input recorder id is correct.")
|
||||
elif recorder_name is not None:
|
||||
logger.warning(
|
||||
f"Please make sure the recorder name {recorder_name} is unique, we will only return the first recorder if there exist several matched the given name."
|
||||
f"Please make sure the recorder name {recorder_name} is unique, we will only return the latest recorder if there exist several matched the given name."
|
||||
)
|
||||
recorders = self.list_recorders()
|
||||
for rid in recorders:
|
||||
@@ -260,7 +291,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -268,10 +299,10 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
self._client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
self._client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
@@ -280,7 +311,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -7,8 +7,11 @@ from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,12 +23,22 @@ class ExpManager:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
self.uri = uri
|
||||
self.default_exp_name = default_exp_name
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._current_uri = uri
|
||||
self._default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
def __repr__(self):
|
||||
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
@@ -38,6 +51,8 @@ class ExpManager:
|
||||
name of the recorder to be started.
|
||||
uri : str
|
||||
the current tracking URI.
|
||||
resume : boolean
|
||||
whether to resume the experiment and recorder.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -45,7 +60,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
@@ -58,7 +73,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -139,9 +154,7 @@ class ExpManager:
|
||||
if self.active_experiment is not None:
|
||||
return self.active_experiment
|
||||
# User don't want get active code now.
|
||||
# Don't assume underlying code could handle the case of two None
|
||||
if experiment_id is None and experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
experiment_name = self._default_exp_name
|
||||
|
||||
if create:
|
||||
exp, is_new = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name)
|
||||
@@ -159,25 +172,23 @@ class ExpManager:
|
||||
automatically create a new experiment based on the given id and name.
|
||||
"""
|
||||
try:
|
||||
if experiment_id is None and experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
return self._get_exp(experiment_id=experiment_id, experiment_name=experiment_name), False
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self.default_exp_name
|
||||
experiment_name = self._default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
"""
|
||||
get specific experiment by name or id. If it does not exist, raise ValueError
|
||||
Get specific experiment by name or id. If it does not exist, raise ValueError.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
experiment_id :
|
||||
The id of experiment
|
||||
experiment_name :
|
||||
The id name experiment
|
||||
The name of experiment
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -203,7 +214,17 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_exp` method.")
|
||||
|
||||
def get_uri(self):
|
||||
@property
|
||||
def default_uri(self):
|
||||
"""
|
||||
Get the default tracking URI from qlib.config.C
|
||||
"""
|
||||
if "kwargs" not in C.exp_manager or "uri" not in C.exp_manager["kwargs"]:
|
||||
raise ValueError("The default URI is not set in qlib.config.C")
|
||||
return C.exp_manager["kwargs"]["uri"]
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
@@ -211,7 +232,31 @@ class ExpManager:
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self.uri
|
||||
return self._current_uri or self.default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
|
||||
"""
|
||||
Set the current tracking URI and the corresponding variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
self._current_uri = self.default_uri
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Customized features for subclasses.
|
||||
self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -229,42 +274,54 @@ class MLflowExpManager(ExpManager):
|
||||
Use mlflow to implement ExpManager.
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
def _set_uri(self):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
logger.info("{:}".format(self._client))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if not hasattr(self, "_client"):
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
resume: bool = False,
|
||||
):
|
||||
# Set the tracking uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name, resume)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
experiment._default_name = self.default_exp_name
|
||||
experiment._default_name = self._default_exp_name
|
||||
|
||||
return experiment
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class SignalRecord(RecordTemp):
|
||||
This is the Signal Record class that generates the signal prediction. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, dataset=None, recorder=None, **kwargs):
|
||||
def __init__(self, model=None, dataset=None, recorder=None):
|
||||
super().__init__(recorder=recorder)
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
@@ -164,13 +164,15 @@ class SigAnaRecord(SignalRecord):
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
|
||||
def generate(self):
|
||||
self.check(parent=True)
|
||||
def generate(self, **kwargs):
|
||||
try:
|
||||
self.check(parent=True)
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
pred = self.load("pred.pkl")
|
||||
label = self.load("label.pkl")
|
||||
@@ -228,7 +230,7 @@ class PortAnaRecord(SignalRecord):
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
"""
|
||||
super().__init__(recorder=recorder)
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.backtest_config = config["backtest"]
|
||||
@@ -236,10 +238,13 @@ class PortAnaRecord(SignalRecord):
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# check previously stored prediction results
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
try:
|
||||
self.check(parent=True) # "Make sure the parent process is completed and store the data properly."
|
||||
except FileExistsError:
|
||||
super().generate()
|
||||
|
||||
# custom strategy and get backtest
|
||||
pred_score = super().load()
|
||||
pred_score = super().load("pred.pkl")
|
||||
report_dict = normal_backtest(pred_score, strategy=self.strategy, **self.backtest_config)
|
||||
report_normal = report_dict.get("report_df")
|
||||
positions_normal = report_dict.get("positions")
|
||||
|
||||
@@ -34,7 +34,7 @@ class Recorder:
|
||||
self.status = Recorder.STATUS_S
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -201,7 +201,7 @@ class MLflowRecorder(Recorder):
|
||||
def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
|
||||
super(MLflowRecorder, self).__init__(experiment_id, name)
|
||||
self._uri = uri
|
||||
self.artifact_uri = None
|
||||
self._artifact_uri = None
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
# construct from mlflow run
|
||||
if mlflow_run is not None:
|
||||
@@ -220,14 +220,51 @@ class MLflowRecorder(Recorder):
|
||||
else None
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
space_length = len(name) + 1
|
||||
return "{name}(info={info},\n{space}uri={uri},\n{space}artifact_uri={artifact_uri},\n{space}client={client})".format(
|
||||
name=name,
|
||||
space=" " * space_length,
|
||||
info=self.info,
|
||||
uri=self.uri,
|
||||
artifact_uri=self.artifact_uri,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
@property
|
||||
def artifact_uri(self):
|
||||
return self._artifact_uri
|
||||
|
||||
def get_local_dir(self):
|
||||
"""
|
||||
This function will return the directory path of this recorder.
|
||||
"""
|
||||
if self.artifact_uri is not None:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
|
||||
local_dir_path = str(local_dir_path.resolve())
|
||||
if os.path.isdir(local_dir_path):
|
||||
return local_dir_path
|
||||
else:
|
||||
raise RuntimeError("This recorder is not saved in the local file system.")
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Please make sure the recorder has been created and started properly before getting artifact uri."
|
||||
)
|
||||
|
||||
def start_run(self):
|
||||
# set the tracking uri
|
||||
mlflow.set_tracking_uri(self._uri)
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the run
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
self.id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self._artifact_uri = run.info.artifact_uri
|
||||
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
@@ -247,7 +284,7 @@ class MLflowRecorder(Recorder):
|
||||
self.status = status
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
if local_path is not None:
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
@@ -259,7 +296,7 @@ class MLflowRecorder(Recorder):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
@@ -289,7 +326,7 @@ class MLflowRecorder(Recorder):
|
||||
)
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return [art.path for art in artifacts]
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ class CheckBin:
|
||||
self.csv_files = sorted(csv_path.glob(f"*{file_suffix}") if csv_path.is_dir() else [csv_path])
|
||||
|
||||
if check_fields is None:
|
||||
check_fields = list(map(lambda x: x.split(".")[0], bin_path_list[0].glob(f"*.bin")))
|
||||
check_fields = list(map(lambda x: x.name.split(".")[0], bin_path_list[0].glob(f"*.bin")))
|
||||
else:
|
||||
check_fields = check_fields.split(",") if isinstance(check_fields, str) else check_fields
|
||||
self.check_fields = list(map(lambda x: x.strip(), check_fields))
|
||||
@@ -91,6 +91,7 @@ class CheckBin:
|
||||
origin_df[self.symbol_field_name] = symbol
|
||||
origin_df.set_index([self.symbol_field_name, self.date_field_name], inplace=True)
|
||||
origin_df.index.names = qlib_df.index.names
|
||||
origin_df = origin_df.reindex(qlib_df.index)
|
||||
try:
|
||||
compare = datacompy.Compare(
|
||||
origin_df,
|
||||
|
||||
430
scripts/data_collector/base.py
Normal file
430
scripts/data_collector/base.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import abc
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
class BaseCollector(abc.ABC):
|
||||
|
||||
CACHE_FLAG = "CACHED"
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
instrument save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.delay = delay
|
||||
self.max_workers = max_workers
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
|
||||
self.instrument_list = sorted(set(self.get_instrument_list()))
|
||||
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.instrument_list = self.instrument_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(start_datetime))
|
||||
if start_datetime
|
||||
else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(end_datetime))
|
||||
if end_datetime
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_instrument_list(self):
|
||||
raise NotImplementedError("rewrite get_instrument_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
"""get data with symbol
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
interval: str
|
||||
value from [1min, 1d]
|
||||
start_datetime: pd.Timestamp
|
||||
end_datetime: pd.Timestamp
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def sleep(self):
|
||||
time.sleep(self.delay)
|
||||
|
||||
def _simple_collector(self, symbol: str):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
|
||||
"""
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
return _result
|
||||
|
||||
def save_instrument(self, symbol, df: pd.DataFrame):
|
||||
"""save instrument data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
instrument code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df is None or df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
instrument_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if instrument_path.exists():
|
||||
_old_df = pd.read_csv(instrument_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(instrument_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
else:
|
||||
if symbol in self.mini_symbol_map:
|
||||
self.mini_symbol_map.pop(symbol)
|
||||
return self.NORMAL_FLAG
|
||||
|
||||
def _collector(self, instrument_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(instrument_list)) as p_bar:
|
||||
for _symbol, _result in zip(instrument_list, executor.map(self._simple_collector, instrument_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(instrument_list)}")
|
||||
error_symbol.extend(self.mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector data......")
|
||||
instrument_list = self.instrument_list
|
||||
for i in range(self.max_collector_count):
|
||||
if not instrument_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
instrument_list = self._collector(instrument_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} instrument list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.instrument_list)}, error: {len(set(instrument_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[BaseNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = Path(self.default_base_dir).joinpath("_source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = Path(self.default_base_dir).joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
self.interval = interval
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def collector_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def normalize_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/instrument_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/instrument_data/source --normalize_dir ~/.qlib/instrument_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, self.normalize_class_name)
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
51
scripts/data_collector/fund/README.md
Normal file
51
scripts/data_collector/fund/README.md
Normal file
@@ -0,0 +1,51 @@
|
||||
# Collect Fund Data
|
||||
|
||||
> *Please pay **ATTENTION** that the data is collected from [天天基金网](https://fund.eastmoney.com/) and the data might not be perfect. We recommend users to prepare their own data if they have high-quality dataset. For more information, users can refer to the [related document](https://qlib.readthedocs.io/en/latest/component/data.html#converting-csv-format-into-qlib-format)*
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collector Data
|
||||
|
||||
|
||||
### CN Data
|
||||
|
||||
#### 1d from East Money
|
||||
|
||||
```bash
|
||||
|
||||
# download from eastmoney.com
|
||||
python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
|
||||
# normalize
|
||||
python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/fund_data/source/cn_1d_nor --qlib_dir ~/.qlib/qlib_data/cn_fund_data --freq day --date_field_name FSRQ --include_fields DWJZ,LJJZ
|
||||
|
||||
```
|
||||
|
||||
### using data
|
||||
|
||||
```python
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/cn_fund_data")
|
||||
df = D.features(D.instruments(market="all"), ["$DWJZ", "$LJJZ"], freq="day")
|
||||
```
|
||||
|
||||
|
||||
### Help
|
||||
```bash
|
||||
pythono collector.py collector_data --help
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- interval: 1d
|
||||
- region: CN
|
||||
312
scripts/data_collector/fund/collector.py
Normal file
312
scripts/data_collector/fund/collector.py
Normal file
@@ -0,0 +1,312 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import abc
|
||||
import sys
|
||||
import copy
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
import json
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_en_fund_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://api.fund.eastmoney.com/f10/lsjz?callback=jQuery_&fundCode={index_code}&pageIndex=1&pageSize={numberOfHistoricalDaysToCrawl}&startDate={startDate}&endDate={endDate}"
|
||||
|
||||
|
||||
class FundCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
fund save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
super(FundCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
self.init_datetime()
|
||||
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
try:
|
||||
# TODO: numberOfHistoricalDaysToCrawl should be bigger enouhg
|
||||
url = INDEX_BENCH_URL.format(
|
||||
index_code=symbol, numberOfHistoricalDaysToCrawl=10000, startDate=start, endDate=end
|
||||
)
|
||||
resp = requests.get(url, headers={"referer": "http://fund.eastmoney.com/110022.html"})
|
||||
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
|
||||
data = json.loads(resp.text.split("(")[-1].split(")")[0])
|
||||
|
||||
# Some funds don't show the net value, example: http://fundf10.eastmoney.com/jjjz_010288.html
|
||||
SYType = data["Data"]["SYType"]
|
||||
if (SYType == "每万份收益") or (SYType == "每百份收益") or (SYType == "每百万份收益"):
|
||||
raise Exception("The fund contains 每*份收益")
|
||||
|
||||
# TODO: should we sort the value by datetime?
|
||||
_resp = pd.DataFrame(data["Data"]["LSJZList"])
|
||||
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
raise ValueError(f"cannot support {interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class FundollectorCN(FundCollector, ABC):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get cn fund symbols......")
|
||||
symbols = get_en_fund_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
return symbols
|
||||
|
||||
def normalize_symbol(self, symbol):
|
||||
return symbol
|
||||
|
||||
@property
|
||||
def _timezone(self):
|
||||
return "Asia/Shanghai"
|
||||
|
||||
|
||||
class FundCollectorCN1d(FundollectorCN):
|
||||
@property
|
||||
def min_numbers_trading(self):
|
||||
return 252 / 4
|
||||
|
||||
|
||||
class FundNormalize(BaseNormalize):
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
@staticmethod
|
||||
def normalize_fund(
|
||||
df: pd.DataFrame,
|
||||
calendar_list: list = None,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
if df.empty:
|
||||
return df
|
||||
df = df.copy()
|
||||
df.set_index(date_field_name, inplace=True)
|
||||
df.index = pd.to_datetime(df.index)
|
||||
df = df[~df.index.duplicated(keep="first")]
|
||||
if calendar_list is not None:
|
||||
df = df.reindex(
|
||||
pd.DataFrame(index=calendar_list)
|
||||
.loc[
|
||||
pd.Timestamp(df.index.min()).date() : pd.Timestamp(df.index.max()).date()
|
||||
+ pd.Timedelta(hours=23, minutes=59)
|
||||
]
|
||||
.index
|
||||
)
|
||||
df.sort_index(inplace=True)
|
||||
|
||||
df.index.names = [date_field_name]
|
||||
return df.reset_index()
|
||||
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
df = self.normalize_fund(df, self._calendar_list, self._date_field_name, self._symbol_field_name)
|
||||
return df
|
||||
|
||||
|
||||
class FundNormalize1d(FundNormalize):
|
||||
pass
|
||||
|
||||
|
||||
class FundNormalizeCN:
|
||||
def _get_calendar_list(self):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class FundNormalizeCN1d(FundNormalizeCN, FundNormalize1d):
|
||||
pass
|
||||
|
||||
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN"], default "CN"
|
||||
"""
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"FundCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"FundNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool # if this param useful?
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/fund_data/source/cn_1d --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
"""
|
||||
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/fund_data/source/cn_1d --normalize_dir ~/.qlib/fund_data/source/cn_1d_nor --region CN --interval 1d --date_field_name FSRQ
|
||||
"""
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(Run)
|
||||
10
scripts/data_collector/fund/requirements.txt
Normal file
10
scripts/data_collector/fund/requirements.txt
Normal file
@@ -0,0 +1,10 @@
|
||||
loguru
|
||||
fire
|
||||
requests
|
||||
numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
json
|
||||
@@ -10,10 +10,10 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies
|
||||
|
||||
# index_name support: SP500, NASDAQ100, DJIA, SP400
|
||||
# help
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import os
|
||||
import time
|
||||
import bisect
|
||||
import pickle
|
||||
@@ -14,6 +15,9 @@ import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from tqdm import tqdm
|
||||
from functools import partial
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
|
||||
HS_SYMBOLS_URL = "http://app.finance.ifeng.com/hq/list.php?type=stock_a&class={s_type}"
|
||||
|
||||
@@ -34,6 +38,7 @@ _BENCH_CALENDAR_LIST = None
|
||||
_ALL_CALENDAR_LIST = None
|
||||
_HS_SYMBOLS = None
|
||||
_US_SYMBOLS = None
|
||||
_EN_FUND_SYMBOLS = None
|
||||
_CALENDAR_MAP = {}
|
||||
|
||||
# NOTE: Until 2020-10-20 20:00:00
|
||||
@@ -93,6 +98,78 @@ def get_calendar_list(bench_code="CSI300") -> list:
|
||||
return calendar
|
||||
|
||||
|
||||
def return_date_list(date_field_name: str, file_path: Path):
|
||||
date_list = pd.read_csv(file_path, sep=",", index_col=0)[date_field_name].to_list()
|
||||
return sorted(map(lambda x: pd.Timestamp(x), date_list))
|
||||
|
||||
|
||||
def get_calendar_list_by_ratio(
|
||||
source_dir: [str, Path],
|
||||
date_field_name: str = "date",
|
||||
threshold: float = 0.5,
|
||||
minimum_count: int = 10,
|
||||
max_workers: int = 16,
|
||||
) -> list:
|
||||
"""get calendar list by selecting the date when few funds trade in this day
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
threshold: float
|
||||
threshold to exclude some days when few funds trade in this day, default 0.5
|
||||
minimum_count: int
|
||||
minimum count of funds should trade in one day
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
|
||||
Returns
|
||||
-------
|
||||
history calendar list
|
||||
"""
|
||||
logger.info(f"get calendar list from {source_dir} by threshold = {threshold}......")
|
||||
|
||||
source_dir = Path(source_dir).expanduser()
|
||||
file_list = list(source_dir.glob("*.csv"))
|
||||
|
||||
_number_all_funds = len(file_list)
|
||||
|
||||
logger.info(f"count how many funds trade in this day......")
|
||||
_dict_count_trade = dict() # dict{date:count}
|
||||
_fun = partial(return_date_list, date_field_name)
|
||||
all_oldest_list = []
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
with ProcessPoolExecutor(max_workers=max_workers) as executor:
|
||||
for date_list in executor.map(_fun, file_list):
|
||||
if date_list:
|
||||
all_oldest_list.append(date_list[0])
|
||||
for date in date_list:
|
||||
if date not in _dict_count_trade.keys():
|
||||
_dict_count_trade[date] = 0
|
||||
|
||||
_dict_count_trade[date] += 1
|
||||
|
||||
p_bar.update()
|
||||
|
||||
logger.info(f"count how many funds have founded in this day......")
|
||||
_dict_count_founding = {date: _number_all_funds for date in _dict_count_trade.keys()} # dict{date:count}
|
||||
with tqdm(total=_number_all_funds) as p_bar:
|
||||
for oldest_date in all_oldest_list:
|
||||
for date in _dict_count_founding.keys():
|
||||
if date < oldest_date:
|
||||
_dict_count_founding[date] -= 1
|
||||
|
||||
calendar = [
|
||||
date
|
||||
for date in _dict_count_trade
|
||||
if _dict_count_trade[date] >= max(int(_dict_count_founding[date] * threshold), minimum_count)
|
||||
]
|
||||
|
||||
return calendar
|
||||
|
||||
|
||||
def get_hs_stock_symbols() -> list:
|
||||
"""get SH/SZ stock symbols
|
||||
|
||||
@@ -220,6 +297,42 @@ def get_us_stock_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
return _US_SYMBOLS
|
||||
|
||||
|
||||
def get_en_fund_symbols(qlib_data_path: [str, Path] = None) -> list:
|
||||
"""get en fund symbols
|
||||
|
||||
Returns
|
||||
-------
|
||||
fund symbols in China
|
||||
"""
|
||||
global _EN_FUND_SYMBOLS
|
||||
|
||||
@deco_retry
|
||||
def _get_eastmoney():
|
||||
url = "http://fund.eastmoney.com/js/fundcode_search.js"
|
||||
resp = requests.get(url)
|
||||
if resp.status_code != 200:
|
||||
raise ValueError("request error")
|
||||
try:
|
||||
_symbols = []
|
||||
for sub_data in re.findall(r"[\[](.*?)[\]]", resp.content.decode().split("= [")[-1].replace("];", "")):
|
||||
data = sub_data.replace('"', "").replace("'", "")
|
||||
# TODO: do we need other informations, like fund_name from ['000001', 'HXCZHH', '华夏成长混合', '混合型', 'HUAXIACHENGZHANGHUNHE']
|
||||
_symbols.append(data.split(",")[0])
|
||||
except Exception as e:
|
||||
logger.warning(f"request error: {e}")
|
||||
raise
|
||||
if len(_symbols) < 8000:
|
||||
raise ValueError("request error")
|
||||
return _symbols
|
||||
|
||||
if _EN_FUND_SYMBOLS is None:
|
||||
_all_symbols = _get_eastmoney()
|
||||
|
||||
_EN_FUND_SYMBOLS = sorted(set(_all_symbols))
|
||||
|
||||
return _EN_FUND_SYMBOLS
|
||||
|
||||
|
||||
def symbol_suffix_to_prefix(symbol: str, capital: bool = True) -> str:
|
||||
"""symbol suffix to prefix
|
||||
|
||||
|
||||
@@ -48,7 +48,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d -
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
```
|
||||
|
||||
@@ -78,7 +78,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="CN")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
```
|
||||
@@ -97,7 +97,7 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
#### 1d from qlib
|
||||
@@ -113,7 +113,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d -
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="US")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
```
|
||||
|
||||
@@ -10,158 +10,26 @@ import importlib
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooData:
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timezone: str = None,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
delay=0,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timezone: str
|
||||
The timezone where the data is located
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
show_1min_logging: bool
|
||||
show 1min logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self._timezone = tzlocal() if timezone is None else timezone
|
||||
self._delay = delay
|
||||
self._interval = interval
|
||||
self._show_1min_logging = show_1min_logging
|
||||
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
def _show_logging_func():
|
||||
if interval == YahooData.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(self, symbol: str) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
show_1min_logging=self._show_1min_logging,
|
||||
)
|
||||
|
||||
_result = None
|
||||
if self._interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
elif self._interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
class YahooCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
@@ -173,7 +41,6 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -197,131 +64,118 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._delay = delay
|
||||
self.max_workers = max_workers
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
self.yahoo_data = YahooData(
|
||||
timezone=self._timezone,
|
||||
super(YahooCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
show_1min_logging=show_1min_logging,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
self.init_datetime()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
def _show_logging_func():
|
||||
if interval == YahooCollector.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return None
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return symbol
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
def _get_data(self, symbol):
|
||||
_result = None
|
||||
df = self.yahoo_data.get_data(symbol)
|
||||
if isinstance(df, pd.DataFrame):
|
||||
if not df.empty:
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(symbol, df) is not None:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
else:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
return _result
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
super(YahooCollector, self).collector_data()
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -329,14 +183,9 @@ class YahooCollector:
|
||||
"""download index data"""
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get HS stock symbos......")
|
||||
symbols = get_hs_stock_symbols()
|
||||
logger.info(f"get {len(symbols)} symbols.")
|
||||
@@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
_begin = self.yahoo_data.start_datetime.strftime(_format)
|
||||
_end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
_begin = self.start_datetime.strftime(_format)
|
||||
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
@@ -396,11 +245,11 @@ class YahooCollectorCN1min(YahooCollectorCN):
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
def get_instrument_list(self):
|
||||
logger.info("get US stock symbols......")
|
||||
symbols = get_us_stock_symbols() + [
|
||||
"^GSPC",
|
||||
@@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS):
|
||||
return 60 * 6.5 * 5
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
class YahooNormalize(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame,
|
||||
@@ -498,11 +328,6 @@ class YahooNormalize:
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""adjusted price"""
|
||||
@@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[YahooNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -789,23 +562,26 @@ class Run:
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN", "US"], default "CN"
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"YahooCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"YahooNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
@@ -815,7 +591,6 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1min_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -835,8 +610,6 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -846,29 +619,13 @@ class Run:
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(
|
||||
self._cur_module, f"YahooCollector{self.region.upper()}{interval}"
|
||||
) # type: Type[YahooCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1min_logging=show_1min_logging,
|
||||
).collector_data()
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
@@ -878,16 +635,7 @@ class Run:
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -219,7 +219,7 @@ class DumpDataBase:
|
||||
# used when creating a bin file
|
||||
date_index = self.get_datetime_index(_df, calendar_list)
|
||||
for field in self.get_dump_fields(_df.columns):
|
||||
bin_path = features_dir.joinpath(f"{field}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
bin_path = features_dir.joinpath(f"{field.lower()}.{self.freq}{self.DUMP_FILE_SUFFIX}")
|
||||
if field not in _df.columns:
|
||||
continue
|
||||
if bin_path.exists() and self._mode == self.UPDATE_MODE:
|
||||
@@ -234,7 +234,7 @@ class DumpDataBase:
|
||||
if isinstance(file_or_data, pd.DataFrame):
|
||||
if file_or_data.empty:
|
||||
return
|
||||
code = fname_to_code(file_or_data.iloc[0][self.symbol_field_name].lower())
|
||||
code = fname_to_code(str(file_or_data.iloc[0][self.symbol_field_name]).lower())
|
||||
df = file_or_data
|
||||
elif isinstance(file_or_data, Path):
|
||||
code = self.get_symbol_from_file(file_or_data)
|
||||
|
||||
1
setup.py
1
setup.py
@@ -56,6 +56,7 @@ REQUIRED = [
|
||||
"joblib>=0.17.0",
|
||||
"ruamel.yaml>=0.16.12",
|
||||
"pymongo==3.7.2", # For task management
|
||||
"scikit-learn>=0.22",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
|
||||
@@ -19,6 +19,7 @@ from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.contrib.workflow.record_temp import SignalMseRecord
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
@@ -96,7 +97,6 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
|
||||
"""train model
|
||||
|
||||
@@ -111,6 +111,9 @@ def train():
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
@@ -119,6 +122,10 @@ def train():
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
# To test get_local_dir
|
||||
print(recorder.get_local_dir())
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
@@ -133,6 +140,59 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana():
|
||||
"""train model followed by SigAnaRecord
|
||||
|
||||
Returns
|
||||
-------
|
||||
pred_score: pandas.DataFrame
|
||||
predict scores
|
||||
performance: dict
|
||||
model performance
|
||||
"""
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow_with_sigana"):
|
||||
R.log_params(**flatten_dict(task))
|
||||
model.fit(dataset)
|
||||
|
||||
# predict and calculate ic and ric
|
||||
recorder = R.get_recorder()
|
||||
sar = SigAnaRecord(recorder, model=model, dataset=dataset)
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
smr = SignalMseRecord(recorder)
|
||||
smr.generate()
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
Returns
|
||||
-------
|
||||
pass_or_not_for_default_uri: bool
|
||||
pass_or_not_for_current_uri: bool
|
||||
temporary_exp_dir: str
|
||||
"""
|
||||
|
||||
# start exp
|
||||
default_uri = R.get_uri()
|
||||
current_uri = "file:./temp-test-exp-mag"
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
@@ -168,12 +228,18 @@ class TestAllFlow(TestAutoData):
|
||||
def tearDownClass(cls) -> None:
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
def test_0_train(self):
|
||||
def test_0_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_train(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
def test_1_backtest(self):
|
||||
def test_2_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
@@ -181,6 +247,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_3_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg="default uri is incorrect")
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
111
tests/test_structured_cov_estimator.py
Normal file
111
tests/test_structured_cov_estimator.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
import numpy as np
|
||||
from scipy.linalg import sqrtm
|
||||
|
||||
from qlib.model.riskmodel import StructuredCovEstimator
|
||||
|
||||
|
||||
class TestStructuredCovEstimator(unittest.TestCase):
|
||||
def test_random_covariance(self):
|
||||
# Try to estimate the covariance from a randomly generated matrix.
|
||||
NUM_VARIABLE = 10
|
||||
NUM_OBSERVATION = 200
|
||||
EPS = 1e-6
|
||||
|
||||
estimator = StructuredCovEstimator(scale_return=False, assume_centered=True)
|
||||
|
||||
X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)
|
||||
|
||||
est_cov = estimator.predict(X, is_price=False)
|
||||
np_cov = np.cov(X.T) # While numpy assume row means variable, qlib assume the other wise.
|
||||
|
||||
delta = abs(est_cov - np_cov)
|
||||
if_identical = (delta < EPS).all()
|
||||
|
||||
self.assertTrue(if_identical)
|
||||
|
||||
def test_nan_option_covariance(self):
|
||||
# Test if nan_option is correctly passed.
|
||||
NUM_VARIABLE = 10
|
||||
NUM_OBSERVATION = 200
|
||||
EPS = 1e-6
|
||||
|
||||
estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option="fill")
|
||||
|
||||
X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)
|
||||
|
||||
est_cov = estimator.predict(X, is_price=False)
|
||||
np_cov = np.cov(X.T) # While numpy assume row means variable, qlib assume the other wise.
|
||||
|
||||
delta = abs(est_cov - np_cov)
|
||||
if_identical = (delta < EPS).all()
|
||||
|
||||
self.assertTrue(if_identical)
|
||||
|
||||
def test_decompose_covariance(self):
|
||||
# Test if return_decomposed_components is correctly passed.
|
||||
NUM_VARIABLE = 10
|
||||
NUM_OBSERVATION = 200
|
||||
|
||||
estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, nan_option="fill")
|
||||
|
||||
X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)
|
||||
|
||||
F, cov_b, var_u = estimator.predict(X, is_price=False, return_decomposed_components=True)
|
||||
|
||||
self.assertTrue(F is not None and cov_b is not None and var_u is not None)
|
||||
|
||||
def test_constructed_covariance(self):
|
||||
# Try to estimate the covariance from a specially crafted matrix.
|
||||
# There should be some significant correlation since X is specially crafted.
|
||||
NUM_VARIABLE = 7
|
||||
NUM_OBSERVATION = 500
|
||||
EPS = 0.1
|
||||
|
||||
estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_VARIABLE - 1)
|
||||
|
||||
sqrt_cov = None
|
||||
while sqrt_cov is None or (np.iscomplex(sqrt_cov)).any():
|
||||
cov = np.random.rand(NUM_VARIABLE, NUM_VARIABLE)
|
||||
for i in range(NUM_VARIABLE):
|
||||
cov[i][i] = 1
|
||||
sqrt_cov = sqrtm(cov)
|
||||
X = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE) @ sqrt_cov
|
||||
|
||||
est_cov = estimator.predict(X, is_price=False)
|
||||
np_cov = np.cov(X.T) # While numpy assume row means variable, qlib assume the other wise.
|
||||
|
||||
delta = abs(est_cov - np_cov)
|
||||
if_identical = (delta < EPS).all()
|
||||
|
||||
self.assertTrue(if_identical)
|
||||
|
||||
def test_decomposition(self):
|
||||
# Try to estimate the covariance from a specially crafted matrix.
|
||||
# The matrix is generated in the assumption that observations can be predicted by multiple factors.
|
||||
NUM_VARIABLE = 30
|
||||
NUM_OBSERVATION = 100
|
||||
NUM_FACTOR = 10
|
||||
EPS = 0.1
|
||||
|
||||
estimator = StructuredCovEstimator(scale_return=False, assume_centered=True, num_factors=NUM_FACTOR)
|
||||
|
||||
F = np.random.rand(NUM_VARIABLE, NUM_FACTOR)
|
||||
B = np.random.rand(NUM_FACTOR, NUM_OBSERVATION)
|
||||
U = np.random.rand(NUM_OBSERVATION, NUM_VARIABLE)
|
||||
X = (F @ B).T + U
|
||||
|
||||
est_cov = estimator.predict(X, is_price=False)
|
||||
np_cov = np.cov(X.T) # While numpy assume row means variable, qlib assume the other wise.
|
||||
|
||||
delta = abs(est_cov - np_cov)
|
||||
if_identical = (delta < EPS).all()
|
||||
|
||||
self.assertTrue(if_identical)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user