mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge remote-tracking branch 'remoteGit/main' into addFund
This commit is contained in:
12
.deepsource.toml
Normal file
12
.deepsource.toml
Normal file
@@ -0,0 +1,12 @@
|
||||
version = 1
|
||||
|
||||
test_patterns = ["tests/test_*.py"]
|
||||
|
||||
exclude_patterns = ["examples/**"]
|
||||
|
||||
[[analyzers]]
|
||||
name = "python"
|
||||
enabled = true
|
||||
|
||||
[analyzers.meta]
|
||||
runtime_version = "3.x.x"
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -34,3 +34,7 @@ tags
|
||||
|
||||
.pytest_cache/
|
||||
.vscode/
|
||||
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
|
||||
@@ -237,6 +237,7 @@ Here is a list of models built on `Qlib`.
|
||||
- [SFM based on pytorch (Liheng Zhang, et al. 2017)](qlib/contrib/model/pytorch_sfm.py)
|
||||
- [TFT based on tensorflow (Bryan Lim, et al. 2019)](examples/benchmarks/TFT/tft.py)
|
||||
- [TabNet based on pytorch (Sercan O. Arik, et al. 2019)](qlib/contrib/model/pytorch_tabnet.py)
|
||||
- [DoubleEnsemble based on LightGBM (Chuheng Zhang, et al. 2020)](qlib/contrib/model/double_ensemble.py)
|
||||
|
||||
Your PR of new Quant models is highly welcomed.
|
||||
|
||||
@@ -318,6 +319,7 @@ Qlib data are stored in a compact format, which is efficient to be combined into
|
||||
|
||||
|
||||
# Related Reports
|
||||
- [【华泰金工林晓明团队】图神经网络选股与Qlib实践——华泰人工智能系列之四十二](https://mp.weixin.qq.com/s/w5fDB6oAv9dO6vlhf1kmhA)
|
||||
- [Guide To Qlib: Microsoft’s AI Investment Platform](https://analyticsindiamag.com/qlib/)
|
||||
- [【华泰金工林晓明团队】微软AI量化投资平台Qlib体验——华泰人工智能系列之四十](https://mp.weixin.qq.com/s/Brcd7im4NibJOJzZfMn6tQ)
|
||||
- [微软也搞AI量化平台?还是开源的!](https://mp.weixin.qq.com/s/47bP5YwxfTp2uTHjUBzJQQ)
|
||||
|
||||
@@ -70,3 +70,31 @@ If the issue is not resolved, use ``keys *`` to find if multiple keys exist. If
|
||||
|
||||
|
||||
Also, feel free to post a new issue in our GitHub repository. We always check each issue carefully and try our best to solve them.
|
||||
|
||||
3. ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
------------------------------------------------------------------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
#### Do not import qlib package in the repository directory in case of importing qlib from . without compiling #####
|
||||
Traceback (most recent call last):
|
||||
File "<stdin>", line 1, in <module>
|
||||
File "qlib/qlib/__init__.py", line 19, in init
|
||||
from .data.cache import H
|
||||
File "qlib/qlib/data/__init__.py", line 8, in <module>
|
||||
from .data import (
|
||||
File "qlib/qlib/data/data.py", line 20, in <module>
|
||||
from .cache import H
|
||||
File "qlib/qlib/data/cache.py", line 36, in <module>
|
||||
from .ops import Operators
|
||||
File "qlib/qlib/data/ops.py", line 19, in <module>
|
||||
from ._libs.rolling import rolling_slope, rolling_rsquare, rolling_resi
|
||||
ModuleNotFoundError: No module named 'qlib.data._libs.rolling'
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with ``PyCharm`` IDE, users can execute the following command in the project root folder to compile Cython files and generate executable files:
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
- If the error occurs when importing ``qlib`` package with command ``python`` , users need to change the running directory to ensure that the script does not run in the project directory.
|
||||
@@ -61,7 +61,7 @@ In addition to China-Stock data, ``Qlib`` also includes a US-Stock dataset, whic
|
||||
|
||||
python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/us_data --region us
|
||||
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/csv_data/cn_data`` directory and ``~/.qlib/csv_data/us_data`` directory respectively.
|
||||
After running the above command, users can find china-stock and us-stock data in ``Qlib`` format in the ``~/.qlib/qlib_data/cn_data`` directory and ``~/.qlib/qlib_data/us_data`` directory respectively.
|
||||
|
||||
``Qlib`` also provides the scripts in ``scripts/data_collector`` to help users crawl the latest data on the Internet and convert it to qlib format.
|
||||
|
||||
@@ -163,7 +163,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in china-stock mode, china-stock data is required. Users can use ``Qlib`` in china-stock mode according to the following steps:
|
||||
- Download china-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in china-stock mode
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/csv_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users download their Qlib format data in the directory ``~/.qlib/qlib_data/cn_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -174,7 +174,7 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t
|
||||
- If users use ``Qlib`` in US-stock mode, US-stock data is required. ``Qlib`` also provides a script to download US-stock data. Users can use ``Qlib`` in US-stock mode according to the following steps:
|
||||
- Download us-stock in qlib format, please refer to section `Qlib Format Dataset <#qlib-format-dataset>`_.
|
||||
- Initialize ``Qlib`` in US-stock mode
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/csv_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
Supposed that users prepare their Qlib format data in the directory ``~/.qlib/qlib_data/us_data``. Users only need to initialize ``Qlib`` as follows.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@@ -218,6 +218,25 @@ Filter
|
||||
- `cross-sectional features filter` \: rule_expression = '$rank($close)<10'
|
||||
- `time-sequence features filter`: rule_expression = '$Ref($close, 3)>100'
|
||||
|
||||
Here is a simple example showing how to use filter in a basic ``Qlib`` workflow configuration file:
|
||||
|
||||
.. code-block:: yaml
|
||||
|
||||
filter: &filter
|
||||
filter_type: ExpressionDFilter
|
||||
rule_expression: "Ref($close, -2) / Ref($close, -1) > 1"
|
||||
filter_start_time: 2010-01-01
|
||||
filter_end_time: 2010-01-07
|
||||
keep: False
|
||||
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2010-01-01
|
||||
end_time: 2021-01-22
|
||||
fit_start_time: 2010-01-01
|
||||
fit_end_time: 2015-12-31
|
||||
instruments: *market
|
||||
filter_pipe: [*filter]
|
||||
|
||||
To know more about ``Filter``, please refer to `Filter API <../reference/api.html#module-qlib.data.filter>`_.
|
||||
|
||||
Reference
|
||||
|
||||
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
4
examples/benchmarks/DoubleEnsemble/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# DoubleEnsemble
|
||||
* DoubleEnsemble is an ensemble framework leveraging learning trajectory based sample reweighting and shuffling based feature selection, to solve both the low signal-to-noise ratio and increasing number of features problems. They identify the key samples based on the training dynamics on each sample and elicit key features based on the ablation impact of each feature via shuffling. The model is applicable to a wide range of base models, capable of extracting complex patterns, while mitigating the overfitting and instability issues for financial market prediction.
|
||||
* This code used in Qlib is implemented by ourselves.
|
||||
* Paper: DoubleEnsemble: A New Ensemble Method Based on Sample Reweighting and Feature Selection for Financial Data Analysis [https://arxiv.org/pdf/2010.01265.pdf](https://arxiv.org/pdf/2010.01265.pdf).
|
||||
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
3
examples/benchmarks/DoubleEnsemble/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.17.4
|
||||
lightgbm==3.1.0
|
||||
@@ -0,0 +1,90 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 28
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,97 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: DEnsembleModel
|
||||
module_path: qlib.contrib.model.double_ensemble
|
||||
kwargs:
|
||||
base_model: "gbm"
|
||||
loss: mse
|
||||
num_models: 6
|
||||
enable_sr: True
|
||||
enable_fs: True
|
||||
alpha1: 1
|
||||
alpha2: 1
|
||||
bins_sr: 10
|
||||
bins_fs: 5
|
||||
decay: 0.5
|
||||
sample_ratios:
|
||||
- 0.8
|
||||
- 0.7
|
||||
- 0.6
|
||||
- 0.5
|
||||
- 0.4
|
||||
sub_weights:
|
||||
- 1
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
- 0.2
|
||||
epochs: 136
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
verbosity: -1
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -16,6 +16,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha360 | 0.0443±0.01 | 0.3401±0.05| 0.0536±0.01 | 0.4248±0.05 | 0.0627±0.03 | 0.8441±0.48| -0.0882±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -25,11 +26,12 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| XGBoost (Tianqi Chen, et al.) | Alpha158 | 0.0481±0.00 | 0.3659±0.00| 0.0495±0.00 | 0.4033±0.00 | 0.1111±0.00 | 1.2915±0.00| -0.0893±0.00 |
|
||||
| LightGBM (Guolin Ke, et al.) | Alpha158 | 0.0475±0.00 | 0.3979±0.00| 0.0485±0.00 | 0.4123±0.00 | 0.1143±0.00 | 1.2744±0.00| -0.0800±0.00 |
|
||||
| MLP | Alpha158 | 0.0358±0.00 | 0.2738±0.03| 0.0425±0.00 | 0.3221±0.01 | 0.0836±0.02 | 1.0323±0.25| -0.1127±0.02 |
|
||||
| TabNet with pretrain (Sercan O. Arikm et al) | Alpha158 | 0.0344±0.00|0.205±0.11|0.0398±0.00 |0.3479±0.01|0.0827±0.02|1.1141±0.32 |-0.0925±0.02 |
|
||||
| TFT (Bryan Lim, et al.) | Alpha158 (with selected 20 features) | 0.0343±0.00 | 0.2071±0.02| 0.0107±0.00 | 0.0660±0.02 | 0.0623±0.02 | 0.5818±0.20| -0.1762±0.01 |
|
||||
| GRU (Kyunghyun Cho, et al.) | Alpha158 (with selected 20 features) | 0.0311±0.00 | 0.2418±0.04| 0.0425±0.00 | 0.3434±0.02 | 0.0330±0.02 | 0.4805±0.30| -0.1021±0.02 |
|
||||
| LSTM (Sepp Hochreiter, et al.) | Alpha158 (with selected 20 features) | 0.0312±0.00 | 0.2394±0.04| 0.0418±0.00 | 0.3324±0.03 | 0.0298±0.02 | 0.4198±0.33| -0.1348±0.03 |
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
Binary file not shown.
@@ -55,7 +55,7 @@ task:
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
pretrain: [2008-01-01, 2014-12-31]
|
||||
pretrain_validation: [2015-01-01, 2020-08-01]
|
||||
pretrain_validation: [2015-01-01, 2016-12-31]
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
|
||||
@@ -105,7 +105,7 @@ _default_config = {
|
||||
"redis_port": 6379,
|
||||
"redis_task_db": 1,
|
||||
# This value can be reset via qlib.init
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
# Global configuration of qlib log
|
||||
# logging_level can control the logging level more finely
|
||||
"logging_config": {
|
||||
@@ -124,12 +124,12 @@ _default_config = {
|
||||
"handlers": {
|
||||
"console": {
|
||||
"class": "logging.StreamHandler",
|
||||
"level": "DEBUG",
|
||||
"level": logging.DEBUG,
|
||||
"formatter": "logger_format",
|
||||
"filters": ["field_not_found"],
|
||||
}
|
||||
},
|
||||
"loggers": {"qlib": {"level": "DEBUG", "handlers": ["console"]}},
|
||||
"loggers": {"qlib": {"level": logging.DEBUG, "handlers": ["console"]}},
|
||||
},
|
||||
# Defatult config for experiment manager
|
||||
"exp_manager": {
|
||||
@@ -185,7 +185,7 @@ MODE_CONF = {
|
||||
# The nfs should be auto-mounted by qlib on other
|
||||
# serversS(such as PAI) [auto_mount:True]
|
||||
"timeout": 100,
|
||||
"logging_level": "INFO",
|
||||
"logging_level": logging.INFO,
|
||||
"region": REG_CN,
|
||||
## Custom Operator
|
||||
"custom_ops": [],
|
||||
|
||||
@@ -104,10 +104,9 @@ class Account:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trader.check_stock_suspended(code, today):
|
||||
continue
|
||||
else:
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
|
||||
@@ -61,7 +61,7 @@ def get_position_value(evaluate_date, position):
|
||||
# load close price for position
|
||||
# position should also consider cash
|
||||
instruments = list(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
fields = ["$close"]
|
||||
close_data_df = D.features(
|
||||
instruments,
|
||||
@@ -80,7 +80,7 @@ def get_position_list_value(positions):
|
||||
instruments = set()
|
||||
for day, position in positions.items():
|
||||
instruments.update(position.keys())
|
||||
instruments = list(set(instruments) - set(["cash"])) # filter 'cash'
|
||||
instruments = list(set(instruments) - {"cash"}) # filter 'cash'
|
||||
instruments.sort()
|
||||
day_list = list(positions.keys())
|
||||
day_list.sort()
|
||||
|
||||
247
qlib/contrib/model/double_ensemble.py
Normal file
247
qlib/contrib/model/double_ensemble.py
Normal file
@@ -0,0 +1,247 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import lightgbm as lgb
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...log import get_module_logger
|
||||
|
||||
|
||||
class DEnsembleModel(Model):
|
||||
"""Double Ensemble Model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_model="gbm",
|
||||
loss="mse",
|
||||
num_models=6,
|
||||
enable_sr=True,
|
||||
enable_fs=True,
|
||||
alpha1=1.0,
|
||||
alpha2=1.0,
|
||||
bins_sr=10,
|
||||
bins_fs=5,
|
||||
decay=None,
|
||||
sample_ratios=None,
|
||||
sub_weights=None,
|
||||
epochs=100,
|
||||
**kwargs
|
||||
):
|
||||
self.base_model = base_model # "gbm" or "mlp", specifically, we use lgbm for "gbm"
|
||||
self.num_models = num_models # the number of sub-models
|
||||
self.enable_sr = enable_sr
|
||||
self.enable_fs = enable_fs
|
||||
self.alpha1 = alpha1
|
||||
self.alpha2 = alpha2
|
||||
self.bins_sr = bins_sr
|
||||
self.bins_fs = bins_fs
|
||||
self.decay = decay
|
||||
if not len(sample_ratios) == bins_fs:
|
||||
raise ValueError("The length of sample_ratios should be equal to bins_fs.")
|
||||
self.sample_ratios = sample_ratios
|
||||
if not len(sub_weights) == num_models:
|
||||
raise ValueError("The length of sub_weights should be equal to num_models.")
|
||||
self.sub_weights = sub_weights
|
||||
self.epochs = epochs
|
||||
self.logger = get_module_logger("DEnsembleModel")
|
||||
self.logger.info("Double Ensemble Model...")
|
||||
self.ensemble = [] # the current ensemble model, a list contains all the sub-models
|
||||
self.sub_features = [] # the features for each sub model in the form of pandas.Index
|
||||
self.params = {"objective": loss}
|
||||
self.params.update(kwargs)
|
||||
self.loss = loss
|
||||
|
||||
def fit(self, dataset: DatasetH):
|
||||
df_train, df_valid = dataset.prepare(
|
||||
["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
|
||||
)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
# initialize the sample weights
|
||||
N, F = x_train.shape
|
||||
weights = pd.Series(np.ones(N, dtype=float))
|
||||
# initialize the features
|
||||
features = x_train.columns
|
||||
pred_sub = pd.DataFrame(np.zeros((N, self.num_models), dtype=float), index=x_train.index)
|
||||
# train sub-models
|
||||
for k in range(self.num_models):
|
||||
self.sub_features.append(features)
|
||||
self.logger.info("Training sub-model: ({}/{})".format(k + 1, self.num_models))
|
||||
model_k = self.train_submodel(df_train, df_valid, weights, features)
|
||||
self.ensemble.append(model_k)
|
||||
# no further sample re-weight and feature selection needed for the last sub-model
|
||||
if k + 1 == self.num_models:
|
||||
break
|
||||
|
||||
self.logger.info("Retrieving loss curve and loss values...")
|
||||
loss_curve = self.retrieve_loss_curve(model_k, df_train, features)
|
||||
pred_k = self.predict_sub(model_k, df_train, features)
|
||||
pred_sub.iloc[:, k] = pred_k
|
||||
pred_ensemble = pred_sub.iloc[:, : k + 1].mean(axis=1)
|
||||
loss_values = pd.Series(self.get_loss(y_train.values.squeeze(), pred_ensemble.values))
|
||||
|
||||
if self.enable_sr:
|
||||
self.logger.info("Sample re-weighting...")
|
||||
weights = self.sample_reweight(loss_curve, loss_values, k + 1)
|
||||
|
||||
if self.enable_fs:
|
||||
self.logger.info("Feature selection...")
|
||||
features = self.feature_selection(df_train, loss_values)
|
||||
|
||||
def train_submodel(self, df_train, df_valid, weights, features):
|
||||
dtrain, dvalid = self._prepare_data_gbm(df_train, df_valid, weights, features)
|
||||
evals_result = dict()
|
||||
model = lgb.train(
|
||||
self.params,
|
||||
dtrain,
|
||||
num_boost_round=self.epochs,
|
||||
valid_sets=[dtrain, dvalid],
|
||||
valid_names=["train", "valid"],
|
||||
verbose_eval=20,
|
||||
evals_result=evals_result,
|
||||
)
|
||||
evals_result["train"] = list(evals_result["train"].values())[0]
|
||||
evals_result["valid"] = list(evals_result["valid"].values())[0]
|
||||
return model
|
||||
|
||||
def _prepare_data_gbm(self, df_train, df_valid, weights, features):
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"].loc[:, features], df_valid["label"]
|
||||
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train, y_valid = np.squeeze(y_train.values), np.squeeze(y_valid.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
dtrain = lgb.Dataset(x_train.values, label=y_train, weight=weights)
|
||||
dvalid = lgb.Dataset(x_valid.values, label=y_valid)
|
||||
return dtrain, dvalid
|
||||
|
||||
def sample_reweight(self, loss_curve, loss_values, k_th):
|
||||
"""
|
||||
the SR module of Double Ensemble
|
||||
:param loss_curve: the shape is NxT
|
||||
the loss curve for the previous sub-model, where the element (i, t) if the error on the i-th sample
|
||||
after the t-th iteration in the training of the previous sub-model.
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:param k_th: the index of the current sub-model, starting from 1
|
||||
:return: weights
|
||||
the weights for all the samples.
|
||||
"""
|
||||
# normalize loss_curve and loss_values with ranking
|
||||
loss_curve_norm = loss_curve.rank(axis=0, pct=True)
|
||||
loss_values_norm = (-loss_values).rank(pct=True)
|
||||
|
||||
# calculate l_start and l_end from loss_curve
|
||||
N, T = loss_curve.shape
|
||||
part = np.maximum(int(T * 0.1), 1)
|
||||
l_start = loss_curve_norm.iloc[:, :part].mean(axis=1)
|
||||
l_end = loss_curve_norm.iloc[:, -part:].mean(axis=1)
|
||||
|
||||
# calculate h-value for each sample
|
||||
h1 = loss_values_norm
|
||||
h2 = (l_end / l_start).rank(pct=True)
|
||||
h = pd.DataFrame({"h_value": self.alpha1 * h1 + self.alpha2 * h2})
|
||||
|
||||
# calculate weights
|
||||
h["bins"] = pd.cut(h["h_value"], self.bins_sr)
|
||||
h_avg = h.groupby("bins")["h_value"].mean()
|
||||
weights = pd.Series(np.zeros(N, dtype=float))
|
||||
for i_b, b in enumerate(h_avg.index):
|
||||
weights[h["bins"] == b] = 1.0 / (self.decay ** k_th * h_avg[i_b] + 0.1)
|
||||
return weights
|
||||
|
||||
def feature_selection(self, df_train, loss_values):
|
||||
"""
|
||||
the FS module of Double Ensemble
|
||||
:param df_train: the shape is NxF
|
||||
:param loss_values: the shape is N
|
||||
the loss of the current ensemble on the i-th sample.
|
||||
:return: res_feat: in the form of pandas.Index
|
||||
|
||||
"""
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
features = x_train.columns
|
||||
N, F = x_train.shape
|
||||
g = pd.DataFrame({"g_value": np.zeros(F, dtype=float)})
|
||||
M = len(self.ensemble)
|
||||
|
||||
# shuffle specific columns and calculate g-value for each feature
|
||||
x_train_tmp = x_train.copy()
|
||||
for i_f, feat in enumerate(features):
|
||||
x_train_tmp.loc[:, feat] = np.random.permutation(x_train_tmp.loc[:, feat].values)
|
||||
pred = pd.Series(np.zeros(N), index=x_train_tmp.index)
|
||||
for i_s, submodel in enumerate(self.ensemble):
|
||||
pred += (
|
||||
pd.Series(
|
||||
submodel.predict(x_train_tmp.loc[:, self.sub_features[i_s]].values), index=x_train_tmp.index
|
||||
)
|
||||
/ M
|
||||
)
|
||||
loss_feat = self.get_loss(y_train.values.squeeze(), pred.values)
|
||||
g.loc[i_f, "g_value"] = np.mean(loss_feat - loss_values) / (np.std(loss_feat - loss_values) + 1e-7)
|
||||
x_train_tmp.loc[:, feat] = x_train.loc[:, feat].copy()
|
||||
|
||||
# one column in train features is all-nan # if g['g_value'].isna().any()
|
||||
g["g_value"].replace(np.nan, 0, inplace=True)
|
||||
|
||||
# divide features into bins_fs bins
|
||||
g["bins"] = pd.cut(g["g_value"], self.bins_fs)
|
||||
|
||||
# randomly sample features from bins to construct the new features
|
||||
res_feat = []
|
||||
sorted_bins = sorted(g["bins"].unique(), reverse=True)
|
||||
for i_b, b in enumerate(sorted_bins):
|
||||
b_feat = features[g["bins"] == b]
|
||||
num_feat = int(np.ceil(self.sample_ratios[i_b] * len(b_feat)))
|
||||
res_feat = res_feat + np.random.choice(b_feat, size=num_feat).tolist()
|
||||
return pd.Index(res_feat)
|
||||
|
||||
def get_loss(self, label, pred):
|
||||
if self.loss == "mse":
|
||||
return (label - pred) ** 2
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
|
||||
def retrieve_loss_curve(self, model, df_train, features):
|
||||
if self.base_model == "gbm":
|
||||
num_trees = model.num_trees()
|
||||
x_train, y_train = df_train["feature"].loc[:, features], df_train["label"]
|
||||
# Lightgbm need 1D array as its label
|
||||
if y_train.values.ndim == 2 and y_train.values.shape[1] == 1:
|
||||
y_train = np.squeeze(y_train.values)
|
||||
else:
|
||||
raise ValueError("LightGBM doesn't support multi-label training")
|
||||
|
||||
N = x_train.shape[0]
|
||||
loss_curve = pd.DataFrame(np.zeros((N, num_trees)))
|
||||
pred_tree = np.zeros(N, dtype=float)
|
||||
for i_tree in range(num_trees):
|
||||
pred_tree += model.predict(x_train.values, start_iteration=i_tree, num_iteration=1)
|
||||
loss_curve.iloc[:, i_tree] = self.get_loss(y_train, pred_tree)
|
||||
else:
|
||||
raise ValueError("not implemented yet")
|
||||
return loss_curve
|
||||
|
||||
def predict(self, dataset):
|
||||
if self.ensemble is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
pred = pd.Series(np.zeros(x_test.shape[0]), index=x_test.index)
|
||||
for i_sub, submodel in enumerate(self.ensemble):
|
||||
feat_sub = self.sub_features[i_sub]
|
||||
pred += (
|
||||
pd.Series(submodel.predict(x_test.loc[:, feat_sub].values), index=x_test.index)
|
||||
* self.sub_weights[i_sub]
|
||||
)
|
||||
return pred
|
||||
|
||||
def predict_sub(self, submodel, df_data, features):
|
||||
x_data, y_data = df_data["feature"].loc[:, features], df_data["label"]
|
||||
pred_sub = pd.Series(submodel.predict(x_data.values), index=x_data.index)
|
||||
return pred_sub
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -39,8 +40,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -76,8 +77,7 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -93,7 +93,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -107,7 +107,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -123,6 +123,9 @@ class ALSTM(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +136,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +208,13 @@ class ALSTM(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +222,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +234,7 @@ class ALSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +296,7 @@ class ALSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(x_batch).detach().numpy()
|
||||
pred = self.ALSTM_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -40,8 +41,8 @@ class ALSTM(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,9 +79,8 @@ class ALSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +96,7 @@ class ALSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +111,7 @@ class ALSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +127,10 @@ class ALSTM(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.ALSTM_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.ALSTM_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.ALSTM_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +141,10 @@ class ALSTM(Model):
|
||||
self.fitted = False
|
||||
self.ALSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +195,13 @@ class ALSTM(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.ALSTM_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +209,6 @@ class ALSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +217,14 @@ class ALSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +281,7 @@ class ALSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.ALSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.ALSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -22,6 +22,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -42,8 +43,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -83,7 +84,7 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
@@ -102,7 +103,7 @@ class GATs(Model):
|
||||
"\nbase_model : {}"
|
||||
"\nwith_pretrain : {}"
|
||||
"\nmodel_path : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -118,7 +119,7 @@ class GATs(Model):
|
||||
base_model,
|
||||
with_pretrain,
|
||||
model_path,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -135,6 +136,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -145,6 +149,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -232,7 +240,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -245,8 +252,7 @@ class GATs(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
@@ -324,10 +330,7 @@ class GATs(Model):
|
||||
x_batch = torch.from_numpy(x_values[batch]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(x_batch).detach().numpy()
|
||||
pred = self.GAT_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -62,8 +63,8 @@ class GATs(Model):
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -104,9 +105,8 @@ class GATs(Model):
|
||||
self.base_model = base_model
|
||||
self.with_pretrain = with_pretrain
|
||||
self.model_path = model_path
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -157,6 +157,9 @@ class GATs(Model):
|
||||
dropout=self.dropout,
|
||||
base_model=self.base_model,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.GAT_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.GAT_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GAT_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -167,6 +170,10 @@ class GATs(Model):
|
||||
self.fitted = False
|
||||
self.GAT_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -245,7 +252,6 @@ class GATs(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -258,11 +264,10 @@ class GATs(Model):
|
||||
sampler_train = DailyBatchSampler(dl_train)
|
||||
sampler_valid = DailyBatchSampler(dl_valid)
|
||||
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(dl_train, sampler=sampler_train, num_workers=self.n_jobs, drop_last=True)
|
||||
valid_loader = DataLoader(dl_valid, sampler=sampler_valid, num_workers=self.n_jobs, drop_last=True)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -345,10 +350,7 @@ class GATs(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GAT_model(feature.float()).detach().numpy()
|
||||
pred = self.GAT_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -76,8 +77,7 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -123,6 +123,9 @@ class GRU(Model):
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.gru_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -133,6 +136,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.gru_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,12 +208,13 @@ class GRU(Model):
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.gru_model(feature)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -214,7 +222,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +234,7 @@ class GRU(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +296,7 @@ class GRU(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.gru_model(x_batch).detach().numpy()
|
||||
pred = self.gru_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -24,6 +24,7 @@ import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH, TSDatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -78,9 +79,8 @@ class GRU(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +96,7 @@ class GRU(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +111,7 @@ class GRU(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -127,7 +127,10 @@ class GRU(Model):
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
).to(self.device)
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.gru_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.gru_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.GRU_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -138,6 +141,10 @@ class GRU(Model):
|
||||
self.fitted = False
|
||||
self.GRU_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -188,12 +195,13 @@ class GRU(Model):
|
||||
# feature[torch.isnan(feature)] = 0
|
||||
label = data[:, -1, -1].to(self.device)
|
||||
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.GRU_model(feature.float())
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -201,7 +209,6 @@ class GRU(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +217,14 @@ class GRU(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +281,7 @@ class GRU(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.GRU_model(feature.float()).detach().numpy()
|
||||
pred = self.GRU_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -76,8 +76,7 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -133,6 +132,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.lstm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -214,7 +217,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -227,8 +229,7 @@ class LSTM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -290,10 +291,7 @@ class LSTM(Model):
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.lstm_model(x_batch).detach().numpy()
|
||||
pred = self.lstm_model(x_batch).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -78,9 +78,8 @@ class LSTM(Model):
|
||||
self.early_stop = early_stop
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.n_jobs = n_jobs
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -96,7 +95,7 @@ class LSTM(Model):
|
||||
"\nearly_stop : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nn_jobs : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
@@ -111,7 +110,7 @@ class LSTM(Model):
|
||||
early_stop,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
n_jobs,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
@@ -138,6 +137,10 @@ class LSTM(Model):
|
||||
self.fitted = False
|
||||
self.LSTM_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def mse(self, pred, label):
|
||||
loss = (pred - label) ** 2
|
||||
return torch.mean(loss)
|
||||
@@ -201,7 +204,6 @@ class LSTM(Model):
|
||||
self,
|
||||
dataset,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
dl_train = dataset.prepare("train", col_set=["feature", "label"], data_key=DataHandlerLP.DK_L)
|
||||
@@ -210,11 +212,14 @@ class LSTM(Model):
|
||||
dl_train.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
dl_valid.config(fillna_type="ffill+bfill") # process nan brought by dataloader
|
||||
|
||||
train_loader = DataLoader(dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs)
|
||||
valid_loader = DataLoader(dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs)
|
||||
train_loader = DataLoader(
|
||||
dl_train, batch_size=self.batch_size, shuffle=True, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
valid_loader = DataLoader(
|
||||
dl_valid, batch_size=self.batch_size, shuffle=False, num_workers=self.n_jobs, drop_last=True
|
||||
)
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
@@ -271,10 +276,7 @@ class LSTM(Model):
|
||||
feature = data[:, :, 0:-1].to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.LSTM_model(feature.float()).detach().numpy()
|
||||
pred = self.LSTM_model(feature.float()).detach().cpu().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
|
||||
@@ -15,10 +15,11 @@ import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, create_save_path, drop_nan_by_y_index
|
||||
from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, get_or_create_path, drop_nan_by_y_index
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...workflow import R
|
||||
|
||||
@@ -42,8 +43,8 @@ class DNNModelPytorch(Model):
|
||||
learning rate decay steps
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -80,8 +81,7 @@ class DNNModelPytorch(Model):
|
||||
self.lr_decay_steps = lr_decay_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss_type = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_GPU = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
self.weight_decay = weight_decay
|
||||
|
||||
@@ -99,7 +99,7 @@ class DNNModelPytorch(Model):
|
||||
"\nloss_type : {}"
|
||||
"\neval_steps : {}"
|
||||
"\nseed : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nweight_decay : {}".format(
|
||||
layers,
|
||||
@@ -114,8 +114,8 @@ class DNNModelPytorch(Model):
|
||||
loss,
|
||||
eval_steps,
|
||||
seed,
|
||||
GPU,
|
||||
self.use_GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
weight_decay,
|
||||
)
|
||||
)
|
||||
@@ -129,6 +129,9 @@ class DNNModelPytorch(Model):
|
||||
self._scorer = mean_squared_error if loss == "mse" else roc_auc_score
|
||||
|
||||
self.dnn_model = Net(input_dim, output_dim, layers, loss=self.loss_type)
|
||||
self.logger.info("model:\n{:}".format(self.dnn_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.dnn_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.dnn_model.parameters(), lr=self.lr, weight_decay=self.weight_decay)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -153,6 +156,10 @@ class DNNModelPytorch(Model):
|
||||
self.fitted = False
|
||||
self.dnn_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
@@ -172,7 +179,7 @@ class DNNModelPytorch(Model):
|
||||
w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index)
|
||||
w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)
|
||||
|
||||
save_path = create_save_path(save_path)
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_loss = np.inf
|
||||
@@ -215,7 +222,8 @@ class DNNModelPytorch(Model):
|
||||
|
||||
# validation
|
||||
train_loss += loss.val
|
||||
if step and step % self.eval_steps == 0:
|
||||
# for evert `eval_steps` steps or at the last steps, we will evaluate the model.
|
||||
if step % self.eval_steps == 0 or step + 1 == self.max_steps:
|
||||
stop_steps += 1
|
||||
train_loss /= self.eval_steps
|
||||
|
||||
@@ -248,9 +256,9 @@ class DNNModelPytorch(Model):
|
||||
# update learning rate
|
||||
self.scheduler.step(cur_loss_val)
|
||||
|
||||
# restore the optimal parameters after training ??
|
||||
# restore the optimal parameters after training
|
||||
self.dnn_model.load_state_dict(torch.load(save_path))
|
||||
if self.use_GPU:
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_loss(self, pred, w, target, loss_type):
|
||||
@@ -272,10 +280,7 @@ class DNNModelPytorch(Model):
|
||||
self.dnn_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_GPU:
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
else:
|
||||
preds = self.dnn_model(x_test).detach().numpy()
|
||||
preds = self.dnn_model(x_test).detach().cpu().numpy()
|
||||
return pd.Series(np.squeeze(preds), index=x_test_pd.index)
|
||||
|
||||
def save(self, filename, **kwargs):
|
||||
|
||||
@@ -13,7 +13,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch.nn as nn
|
||||
import torch.nn.init as init
|
||||
import torch.optim as optim
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -196,8 +197,8 @@ class SFM(Model):
|
||||
learning rate
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
GPU : int
|
||||
the GPU ID used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -216,7 +217,7 @@ class SFM(Model):
|
||||
eval_steps=5,
|
||||
loss="mse",
|
||||
optimizer="gd",
|
||||
GPU="0",
|
||||
GPU=0,
|
||||
seed=None,
|
||||
**kwargs
|
||||
):
|
||||
@@ -239,8 +240,7 @@ class SFM(Model):
|
||||
self.eval_steps = eval_steps
|
||||
self.optimizer = optimizer.lower()
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu")
|
||||
self.seed = seed
|
||||
|
||||
self.logger.info(
|
||||
@@ -259,7 +259,7 @@ class SFM(Model):
|
||||
"\neval_steps : {}"
|
||||
"\noptimizer : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\ndevice : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
@@ -276,7 +276,7 @@ class SFM(Model):
|
||||
eval_steps,
|
||||
optimizer.lower(),
|
||||
loss,
|
||||
GPU,
|
||||
self.device,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
@@ -295,6 +295,9 @@ class SFM(Model):
|
||||
dropout_U=self.dropout_U,
|
||||
device=self.device,
|
||||
)
|
||||
self.logger.info("model:\n{:}".format(self.sfm_model))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters(self.sfm_model)))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.train_optimizer = optim.Adam(self.sfm_model.parameters(), lr=self.lr)
|
||||
elif optimizer.lower() == "gd":
|
||||
@@ -305,6 +308,10 @@ class SFM(Model):
|
||||
self.fitted = False
|
||||
self.sfm_model.to(self.device)
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
@@ -365,7 +372,6 @@ class SFM(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
@@ -377,6 +383,7 @@ class SFM(Model):
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
|
||||
save_path = get_or_create_path(save_path)
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = -np.inf
|
||||
@@ -409,7 +416,10 @@ class SFM(Model):
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.sfm_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
if self.device != "cpu":
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
get_or_create_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
@@ -23,6 +23,7 @@ import torch.optim as optim
|
||||
import torch.nn.functional as F
|
||||
from torch.autograd import Function
|
||||
|
||||
from .pytorch_utils import count_parameters
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
@@ -49,12 +50,12 @@ class TabnetModel(Model):
|
||||
loss="mse",
|
||||
metric="",
|
||||
early_stop=20,
|
||||
GPU="1",
|
||||
GPU=0,
|
||||
pretrain_loss="custom",
|
||||
ps=0.3,
|
||||
lr=0.01,
|
||||
pretrain=True,
|
||||
pretrain_file="./pretrain/best.model",
|
||||
pretrain_file=None,
|
||||
):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
@@ -75,18 +76,18 @@ class TabnetModel(Model):
|
||||
self.n_epochs = n_epochs
|
||||
self.logger = get_module_logger("TabNet")
|
||||
self.pretrain_n_epochs = pretrain_n_epochs
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() else "cpu"
|
||||
self.device = "cuda:%s" % (GPU) if torch.cuda.is_available() and GPU >= 0 else "cpu"
|
||||
self.loss = loss
|
||||
self.metric = metric
|
||||
self.early_stop = early_stop
|
||||
self.pretrain = pretrain
|
||||
self.pretrain_file = pretrain_file
|
||||
self.pretrain_file = get_or_create_path(pretrain_file)
|
||||
self.logger.info(
|
||||
"TabNet:"
|
||||
"\nbatch_size : {}"
|
||||
"\nvirtual bs : {}"
|
||||
"\nGPU : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, GPU, pretrain)
|
||||
"\ndevice : {}"
|
||||
"\npretrain: {}".format(self.batch_size, vbs, self.device, self.pretrain)
|
||||
)
|
||||
self.fitted = False
|
||||
np.random.seed(self.seed)
|
||||
@@ -98,6 +99,8 @@ class TabnetModel(Model):
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
@@ -113,11 +116,12 @@ class TabnetModel(Model):
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(optimizer))
|
||||
|
||||
@property
|
||||
def use_gpu(self):
|
||||
return self.device != torch.device("cpu")
|
||||
|
||||
def pretrain_fn(self, dataset=DatasetH, pretrain_file="./pretrain/best.model"):
|
||||
# make a directory if pretrian director does not exist
|
||||
if pretrain_file.startswith("./pretrain") and not os.path.exists("pretrain"):
|
||||
self.logger.info("make folder to store model...")
|
||||
os.makedirs("pretrain")
|
||||
get_or_create_path(pretrain_file)
|
||||
|
||||
[df_train, df_valid] = dataset.prepare(
|
||||
["pretrain", "pretrain_validation"],
|
||||
@@ -159,7 +163,6 @@ class TabnetModel(Model):
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
if self.pretrain:
|
||||
@@ -179,10 +182,11 @@ class TabnetModel(Model):
|
||||
df_train.fillna(df_train.mean(), inplace=True)
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
save_path = get_or_create_path(save_path)
|
||||
|
||||
stop_steps = 0
|
||||
train_loss = 0
|
||||
best_score = np.inf
|
||||
best_score = -np.inf
|
||||
best_epoch = 0
|
||||
evals_result["train"] = []
|
||||
evals_result["valid"] = []
|
||||
@@ -201,16 +205,23 @@ class TabnetModel(Model):
|
||||
evals_result["train"].append(train_score)
|
||||
evals_result["valid"].append(val_score)
|
||||
|
||||
if val_score < best_score:
|
||||
if val_score > best_score:
|
||||
best_score = val_score
|
||||
stop_steps = 0
|
||||
best_epoch = epoch_idx
|
||||
best_param = copy.deepcopy(self.tabnet_model.state_dict())
|
||||
else:
|
||||
stop_steps += 1
|
||||
if stop_steps >= self.early_stop:
|
||||
self.logger.info("early stop")
|
||||
break
|
||||
|
||||
self.logger.info("best score: %.6lf @ %d" % (best_score, best_epoch))
|
||||
self.tabnet_model.load_state_dict(best_param)
|
||||
torch.save(best_param, save_path)
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
@@ -260,12 +271,13 @@ class TabnetModel(Model):
|
||||
feature = x_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
label = y_values[indices[i : i + self.batch_size]].float().to(self.device)
|
||||
priors = torch.ones(self.batch_size, self.d_feat).to(self.device)
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
with torch.no_grad():
|
||||
pred = self.tabnet_model(feature, priors)
|
||||
loss = self.loss_fn(pred, label)
|
||||
losses.append(loss.item())
|
||||
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
score = self.metric_fn(pred, label)
|
||||
scores.append(score.item())
|
||||
|
||||
return np.mean(losses), np.mean(scores)
|
||||
|
||||
@@ -348,10 +360,11 @@ class TabnetModel(Model):
|
||||
label = y_train_values.float().to(self.device)
|
||||
S_mask = S_mask.to(self.device)
|
||||
priors = 1 - S_mask
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
with torch.no_grad():
|
||||
(vec, sparse_loss) = self.tabnet_model(feature, priors)
|
||||
f = self.tabnet_decoder(vec)
|
||||
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
loss = self.pretrain_loss_fn(label, f, S_mask)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
37
qlib/contrib/model/pytorch_utils.py
Normal file
37
qlib/contrib/model/pytorch_utils.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(models_or_parameters, unit="m"):
|
||||
"""
|
||||
This function is to obtain the storage size unit of a (or multiple) models.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
models_or_parameters : PyTorch model(s) or a list of parameters.
|
||||
unit : the storage size unit.
|
||||
|
||||
Returns
|
||||
-------
|
||||
The number of parameters of the given model(s) or parameters.
|
||||
"""
|
||||
if isinstance(models_or_parameters, nn.Module):
|
||||
counts = sum(v.numel() for v in models_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
counts = models_or_parameters.numel()
|
||||
elif isinstance(models_or_parameters, (list, tuple)):
|
||||
return sum(count_parameters(x, unit) for x in models_or_parameters)
|
||||
else:
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
unit = unit.lower()
|
||||
if unit == "kb" or unit == "k":
|
||||
counts /= 2 ** 10
|
||||
elif unit == "mb" or unit == "m":
|
||||
counts /= 2 ** 20
|
||||
elif unit == "gb" or unit == "g":
|
||||
counts /= 2 ** 30
|
||||
elif unit is not None:
|
||||
raise ValueError("Unknow unit: {:}".format(unit))
|
||||
return counts
|
||||
@@ -63,7 +63,7 @@ class UserManager:
|
||||
account_path = self.data_path / user_id
|
||||
strategy_file = self.data_path / user_id / "strategy_{}.pickle".format(user_id)
|
||||
model_file = self.data_path / user_id / "model_{}.pickle".format(user_id)
|
||||
cur_user_list = [user_id for user_id in self.users]
|
||||
cur_user_list = list(self.users)
|
||||
if user_id in cur_user_list:
|
||||
raise ValueError("User {} has been loaded".format(user_id))
|
||||
else:
|
||||
|
||||
@@ -148,7 +148,7 @@ class Operator:
|
||||
for user_id, user in um.users.items():
|
||||
dates, trade_exchange = prepare(um, trade_date, user_id, exchange_config)
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange)
|
||||
if not str(dates[0].date()) == str(pred_date.date()):
|
||||
if str(dates[0].date()) != str(pred_date.date()):
|
||||
raise ValueError(
|
||||
"The account data is not newest! last trading date {}, today {}".format(
|
||||
dates[0].date(), trade_date.date()
|
||||
|
||||
@@ -161,7 +161,7 @@ class DistplotGraph(BaseGraph):
|
||||
"""
|
||||
_t_df = self._df.dropna()
|
||||
_data_list = [_t_df[_col] for _col in self._name_dict]
|
||||
_label_list = [_name for _name in self._name_dict.values()]
|
||||
_label_list = list(self._name_dict.values())
|
||||
_fig = create_distplot(_data_list, _label_list, show_rug=False, **self._graph_kwargs)
|
||||
|
||||
return _fig["data"]
|
||||
|
||||
@@ -1045,9 +1045,6 @@ class SimpleDatasetCache(DatasetCache):
|
||||
class DatasetURICache(DatasetCache):
|
||||
"""Prepared cache mechanism for server."""
|
||||
|
||||
def __init__(self, provider):
|
||||
super(DatasetURICache, self).__init__(provider)
|
||||
|
||||
def _uri(self, instruments, fields, start_time, end_time, freq, disk_cache=1, **kwargs):
|
||||
return hash_args(*self.normalize_uri_args(instruments, fields, freq), disk_cache)
|
||||
|
||||
|
||||
@@ -654,9 +654,6 @@ class LocalExpressionProvider(ExpressionProvider):
|
||||
Provide expression data from local data source.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def expression(self, instrument, field, start_time=None, end_time=None, freq="day"):
|
||||
expression = self.get_expression_instance(field)
|
||||
start_time = pd.Timestamp(start_time)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
@@ -76,17 +76,6 @@ class DatasetH(Dataset):
|
||||
- The processing is related to data split.
|
||||
"""
|
||||
|
||||
def __init__(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
handler : Union[dict, DataHandler]
|
||||
handler will be passed into setup_data.
|
||||
segments : dict
|
||||
handler will be passed into setup_data.
|
||||
"""
|
||||
super().__init__(handler, segments)
|
||||
|
||||
def init(self, handler_kwargs: dict = None, segment_kwargs: dict = None):
|
||||
"""
|
||||
Initialize the DatasetH
|
||||
@@ -124,7 +113,7 @@ class DatasetH(Dataset):
|
||||
raise TypeError(f"param handler_kwargs must be type dict, not {type(segment_kwargs)}")
|
||||
self.segments = segment_kwargs.copy()
|
||||
|
||||
def setup_data(self, handler: Union[dict, DataHandler], segments: dict):
|
||||
def setup_data(self, handler: Union[Dict, DataHandler], segments: Dict[Text, Tuple]):
|
||||
"""
|
||||
Setup the underlying data.
|
||||
|
||||
@@ -156,6 +145,11 @@ class DatasetH(Dataset):
|
||||
self.handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self.segments = segments.copy()
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(handler={handler}, segments={segments})".format(
|
||||
name=self.__class__.__name__, handler=self.handler, segments=self.segments
|
||||
)
|
||||
|
||||
def _prepare_seg(self, slc: slice, **kwargs):
|
||||
"""
|
||||
Give a slice, retrieve the according data
|
||||
@@ -168,7 +162,7 @@ class DatasetH(Dataset):
|
||||
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
segments: Union[List[Text], Tuple[Text], Text, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
@@ -178,7 +172,7 @@ class DatasetH(Dataset):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
segments : Union[List[str], Tuple[str], str, slice]
|
||||
segments : Union[List[Text], Tuple[Text], Text, slice]
|
||||
Describe the scope of the data to be prepared
|
||||
Here are some examples:
|
||||
|
||||
@@ -408,7 +402,7 @@ class TSDataSampler:
|
||||
# 1) for better performance, use the last nan line for padding the lost date
|
||||
# 2) In case of precision problems. We use np.float64. # TODO: I'm not sure if whether np.float64 will result in
|
||||
# precision problems. It will not cause any problems in my tests at least
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(np.int)
|
||||
indices = np.nan_to_num(indices.astype(np.float64), nan=self.nan_idx).astype(int)
|
||||
|
||||
data = self.data_arr[indices]
|
||||
if isinstance(idx, mtit):
|
||||
|
||||
@@ -35,7 +35,7 @@ class DataHandler(Serializable):
|
||||
The data handler try to maintain a handler with 2 level.
|
||||
`datetime` & `instruments`.
|
||||
|
||||
Any order of the index level can be suported(The order will implied in the data).
|
||||
Any order of the index level can be suported (The order will be implied in the data).
|
||||
The order <`datetime`, `instruments`> will be used when the dataframe index name is missed.
|
||||
|
||||
Example of the data:
|
||||
@@ -47,8 +47,8 @@ class DataHandler(Serializable):
|
||||
$close $volume Ref($close, 1) Mean($close, 3) $high-$low LABEL0
|
||||
datetime instrument
|
||||
2010-01-04 SH600000 81.807068 17145150.0 83.737389 83.016739 2.741058 0.0032
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
SH600004 13.313329 11800983.0 13.313329 13.317701 0.183632 0.0042
|
||||
SH600005 37.796539 12231662.0 38.258602 37.919757 0.970325 0.0289
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@@ -74,7 +74,6 @@ class NpElemOperator(ElemOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature, func):
|
||||
self.feature = feature
|
||||
self.func = func
|
||||
super(NpElemOperator, self).__init__(feature)
|
||||
|
||||
@@ -289,8 +288,6 @@ class NpPairOperator(PairOperator):
|
||||
"""
|
||||
|
||||
def __init__(self, feature_left, feature_right, func):
|
||||
self.feature_left = feature_left
|
||||
self.feature_right = feature_right
|
||||
self.func = func
|
||||
super(NpPairOperator, self).__init__(feature_left, feature_right)
|
||||
|
||||
@@ -1182,7 +1179,7 @@ class Slope(Rolling):
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with regression slope of given window
|
||||
a feature instance with linear regression slope of given window
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
@@ -1210,7 +1207,7 @@ class Rsquare(Rolling):
|
||||
Returns
|
||||
----------
|
||||
Expression
|
||||
a feature instance with regression r-value square of given window
|
||||
a feature instance with linear regression r-value square of given window
|
||||
"""
|
||||
|
||||
def __init__(self, feature, N):
|
||||
@@ -1489,7 +1486,7 @@ OpsList = [
|
||||
]
|
||||
|
||||
|
||||
class OpsWrapper(object):
|
||||
class OpsWrapper:
|
||||
"""Ops Wrapper"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
10
qlib/log.py
10
qlib/log.py
@@ -3,8 +3,7 @@
|
||||
|
||||
|
||||
import logging
|
||||
import logging.handlers
|
||||
import os
|
||||
from typing import Optional, Text, Dict, Any
|
||||
import re
|
||||
from logging import config as logging_config
|
||||
from time import time
|
||||
@@ -13,16 +12,13 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level=None):
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
:param module_name: str
|
||||
Logic module name.
|
||||
:param level: int
|
||||
:param sh_level: int
|
||||
Stream handler log level.
|
||||
:param log_format: str
|
||||
:return: Logger
|
||||
Logger object.
|
||||
"""
|
||||
@@ -103,7 +99,7 @@ class TimeInspector:
|
||||
cls.log_cost_time(info=f"{name} Done")
|
||||
|
||||
|
||||
def set_log_with_config(log_config: dict):
|
||||
def set_log_with_config(log_config: Dict[Text, Any]):
|
||||
"""set log with config
|
||||
|
||||
:param log_config:
|
||||
|
||||
@@ -24,7 +24,7 @@ import collections
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from typing import Union, Tuple
|
||||
from typing import Union, Tuple, Text, Optional
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger, set_log_with_config
|
||||
@@ -64,7 +64,7 @@ def np_ffill(arr: np.array):
|
||||
arr : np.array
|
||||
Input numpy 1D array
|
||||
"""
|
||||
mask = np.isnan(arr.astype(np.float)) # np.isnan only works on np.float
|
||||
mask = np.isnan(arr.astype(float)) # np.isnan only works on np.float
|
||||
# get fill index
|
||||
idx = np.where(~mask, np.arange(mask.shape[0]), 0)
|
||||
np.maximum.accumulate(idx, out=idx)
|
||||
@@ -212,7 +212,7 @@ def get_cls_kwargs(config: Union[dict, str], module) -> (type, dict):
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = tuple([]), **kwargs
|
||||
config: Union[str, dict, object], module=None, accept_types: Union[type, Tuple[type]] = (), **kwargs
|
||||
) -> object:
|
||||
"""
|
||||
get initialized instance with config
|
||||
@@ -276,23 +276,31 @@ def compare_dict_value(src_data: dict, dst_data: dict):
|
||||
return changes
|
||||
|
||||
|
||||
def create_save_path(save_path=None):
|
||||
"""Create save path
|
||||
def get_or_create_path(path: Optional[Text] = None, return_dir: bool = False):
|
||||
"""Create or get a file or directory given the path and return_dir.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_path: str
|
||||
path: a string indicates the path or None indicates creating a temporary path.
|
||||
return_dir: if True, create and return a directory; otherwise c&r a file.
|
||||
|
||||
"""
|
||||
if save_path:
|
||||
if not os.path.exists(save_path):
|
||||
os.makedirs(save_path)
|
||||
if path:
|
||||
if return_dir and not os.path.exists(path):
|
||||
os.makedirs(path)
|
||||
elif not return_dir: # return a file, thus we need to create its parent directory
|
||||
xpath = os.path.abspath(os.path.join(path, ".."))
|
||||
if not os.path.exists(xpath):
|
||||
os.makedirs(xpath)
|
||||
else:
|
||||
temp_dir = os.path.expanduser("~/tmp")
|
||||
if not os.path.exists(temp_dir):
|
||||
os.makedirs(temp_dir)
|
||||
_, save_path = tempfile.mkstemp(dir=temp_dir)
|
||||
return save_path
|
||||
if return_dir:
|
||||
_, path = tempfile.mkdtemp(dir=temp_dir)
|
||||
else:
|
||||
_, path = tempfile.mkstemp(dir=temp_dir)
|
||||
return path
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@@ -722,6 +730,9 @@ class Wrapper:
|
||||
def register(self, provider):
|
||||
self._provider = provider
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(provider={provider})".format(name=self.__class__.__name__, provider=self._provider)
|
||||
|
||||
def __getattr__(self, key):
|
||||
if self._provider is None:
|
||||
raise AttributeError("Please run qlib.init() first using qlib")
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from typing import Text, Optional
|
||||
from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
@@ -16,8 +17,13 @@ class QlibRecorder:
|
||||
def __init__(self, exp_manager):
|
||||
self.exp_manager = exp_manager
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(manager={manager})".format(name=self.__class__.__name__, manager=self.exp_manager)
|
||||
|
||||
@contextmanager
|
||||
def start(self, experiment_name=None, recorder_name=None):
|
||||
def start(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
"""
|
||||
Method to start an experiment. This method can only be called within a Python's `with` statement. Here is the example code:
|
||||
|
||||
@@ -34,8 +40,13 @@ class QlibRecorder:
|
||||
name of the experiment one wants to start.
|
||||
recorder_name : str
|
||||
name of the recorder under the experiment one wants to start.
|
||||
uri : str
|
||||
The tracking uri of the experiment, where all the artifacts/metrics etc. will be stored.
|
||||
The default uri is set in the qlib.config. Note that this uri argument will not change the one defined in the config file.
|
||||
Therefore, the next time when users call this function in the same experiment,
|
||||
they have to also specify this argument with the same value. Otherwise, inconsistent uri may occur.
|
||||
"""
|
||||
run = self.start_exp(experiment_name, recorder_name)
|
||||
run = self.start_exp(experiment_name, recorder_name, uri)
|
||||
try:
|
||||
yield run
|
||||
except Exception as e:
|
||||
@@ -272,7 +283,13 @@ class QlibRecorder:
|
||||
-------
|
||||
The uri of current experiment manager.
|
||||
"""
|
||||
return self.exp_manager.get_uri()
|
||||
return self.exp_manager.uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text]):
|
||||
"""
|
||||
Method to reset the current uri of current experiment manager.
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None):
|
||||
"""
|
||||
|
||||
@@ -16,7 +16,7 @@ def get_path_list(path):
|
||||
if isinstance(path, str):
|
||||
return [path]
|
||||
else:
|
||||
return [p for p in path]
|
||||
return list(path)
|
||||
|
||||
|
||||
def sys_config(config, config_path):
|
||||
|
||||
@@ -23,7 +23,7 @@ class Experiment:
|
||||
self.active_recorder = None # only one recorder can running each time
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -173,11 +173,12 @@ class MLflowExperiment(Experiment):
|
||||
self._uri = uri
|
||||
self._default_name = None
|
||||
self._default_rec_name = "mlflow_recorder"
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
|
||||
def __repr__(self):
|
||||
return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info)
|
||||
|
||||
def start(self, recorder_name=None):
|
||||
# set the active experiment
|
||||
mlflow.set_experiment(self.name)
|
||||
logger.info(f"Experiment {self.id} starts running ...")
|
||||
# set up recorder
|
||||
recorder = self.create_recorder(recorder_name)
|
||||
@@ -210,7 +211,6 @@ class MLflowExperiment(Experiment):
|
||||
else:
|
||||
recorder, is_new = self._get_recorder(recorder_id=recorder_id, recorder_name=recorder_name), False
|
||||
if is_new:
|
||||
mlflow.set_experiment(self.name)
|
||||
self.active_recorder = recorder
|
||||
# start the recorder
|
||||
self.active_recorder.start_run()
|
||||
@@ -239,7 +239,7 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input at least one of recorder id or name before retrieving recorder."
|
||||
if recorder_id is not None:
|
||||
try:
|
||||
run = self.client.get_run(recorder_id)
|
||||
run = self._client.get_run(recorder_id)
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=run)
|
||||
return recorder
|
||||
except MlflowException:
|
||||
@@ -260,7 +260,7 @@ class MLflowExperiment(Experiment):
|
||||
max_results = 100000 if kwargs.get("max_results") is None else kwargs.get("max_results")
|
||||
order_by = kwargs.get("order_by")
|
||||
|
||||
return self.client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
return self._client.search_runs([self.id], filter_string, run_view_type, max_results, order_by)
|
||||
|
||||
def delete_recorder(self, recorder_id=None, recorder_name=None):
|
||||
assert (
|
||||
@@ -268,10 +268,10 @@ class MLflowExperiment(Experiment):
|
||||
), "Please input a valid recorder id or name before deleting."
|
||||
try:
|
||||
if recorder_id is not None:
|
||||
self.client.delete_run(recorder_id)
|
||||
self._client.delete_run(recorder_id)
|
||||
else:
|
||||
recorder = self._get_recorder(recorder_name=recorder_name)
|
||||
self.client.delete_run(recorder.id)
|
||||
self._client.delete_run(recorder.id)
|
||||
except MlflowException as e:
|
||||
raise Exception(
|
||||
f"Error: {e}. Something went wrong when deleting recorder. Please check if the name/id of the recorder is correct."
|
||||
@@ -280,7 +280,7 @@ class MLflowExperiment(Experiment):
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
runs = self.client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)[::-1]
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
|
||||
@@ -7,8 +7,11 @@ from mlflow.entities import ViewType
|
||||
import os
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
|
||||
from .exp import MLflowExperiment, Experiment
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
@@ -20,12 +23,21 @@ class ExpManager:
|
||||
(The link: https://mlflow.org/docs/latest/python_api/mlflow.html)
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
self.uri = uri
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
self._current_uri = uri
|
||||
self.default_exp_name = default_exp_name
|
||||
self.active_experiment = None # only one experiment can active each time
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None, **kwargs):
|
||||
def __repr__(self):
|
||||
return "{name}(current_uri={curi})".format(name=self.__class__.__name__, curi=self._current_uri)
|
||||
|
||||
def start_exp(
|
||||
self,
|
||||
experiment_name: Optional[Text] = None,
|
||||
recorder_name: Optional[Text] = None,
|
||||
uri: Optional[Text] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Start an experiment. This method includes first get_or_create an experiment, and then
|
||||
set it to be active.
|
||||
@@ -45,7 +57,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `start_exp` method.")
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S, **kwargs):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S, **kwargs):
|
||||
"""
|
||||
End an active experiment.
|
||||
|
||||
@@ -58,7 +70,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `end_exp` method.")
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
"""
|
||||
Create an experiment.
|
||||
|
||||
@@ -203,7 +215,17 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `delete_exp` method.")
|
||||
|
||||
def get_uri(self):
|
||||
@property
|
||||
def default_uri(self):
|
||||
"""
|
||||
Get the default tracking URI from qlib.config.C
|
||||
"""
|
||||
if "kwargs" not in C.exp_manager or "uri" not in C.exp_manager["kwargs"]:
|
||||
raise ValueError("The default URI is not set in qlib.config.C")
|
||||
return C.exp_manager["kwargs"]["uri"]
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
"""
|
||||
Get the default tracking URI or current URI.
|
||||
|
||||
@@ -211,7 +233,31 @@ class ExpManager:
|
||||
-------
|
||||
The tracking URI string.
|
||||
"""
|
||||
return self.uri
|
||||
return self._current_uri or self.default_uri
|
||||
|
||||
def set_uri(self, uri: Optional[Text] = None):
|
||||
"""
|
||||
Set the current tracking URI and the corresponding variables.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
uri : str
|
||||
|
||||
"""
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
self._current_uri = self.default_uri
|
||||
else:
|
||||
# Temporarily re-set the current uri as the uri argument.
|
||||
self._current_uri = uri
|
||||
# Customized features for subclasses.
|
||||
self._set_uri()
|
||||
|
||||
def _set_uri(self):
|
||||
"""
|
||||
Customized features for subclasses' set_uri function.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_set_uri` method.")
|
||||
|
||||
def list_experiments(self):
|
||||
"""
|
||||
@@ -229,37 +275,43 @@ class MLflowExpManager(ExpManager):
|
||||
Use mlflow to implement ExpManager.
|
||||
"""
|
||||
|
||||
def __init__(self, uri, default_exp_name):
|
||||
def __init__(self, uri: Text, default_exp_name: Optional[Text]):
|
||||
super(MLflowExpManager, self).__init__(uri, default_exp_name)
|
||||
self._client = None
|
||||
|
||||
def _set_uri(self):
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
logger.info("{:}".format(self._client))
|
||||
|
||||
@property
|
||||
def client(self):
|
||||
# Delay the creation of mlflow client in case of creating `mlruns` folder when importing qlib
|
||||
if not hasattr(self, "_client"):
|
||||
if self._client is None:
|
||||
self._client = mlflow.tracking.MlflowClient(tracking_uri=self.uri)
|
||||
return self._client
|
||||
|
||||
def start_exp(self, experiment_name=None, recorder_name=None, uri=None):
|
||||
# set the tracking uri
|
||||
if uri is None:
|
||||
logger.info("No tracking URI is provided. Use the default tracking URI.")
|
||||
else:
|
||||
self.uri = uri
|
||||
# create experiment
|
||||
def start_exp(
|
||||
self, experiment_name: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None
|
||||
):
|
||||
# Set the tracking uri
|
||||
self.set_uri(uri)
|
||||
# Create experiment
|
||||
experiment, _ = self._get_or_create_exp(experiment_name=experiment_name)
|
||||
# set up active experiment
|
||||
# Set up active experiment
|
||||
self.active_experiment = experiment
|
||||
# start the experiment
|
||||
# Start the experiment
|
||||
self.active_experiment.start(recorder_name)
|
||||
|
||||
return self.active_experiment
|
||||
|
||||
def end_exp(self, recorder_status: str = Recorder.STATUS_S):
|
||||
def end_exp(self, recorder_status: Text = Recorder.STATUS_S):
|
||||
if self.active_experiment is not None:
|
||||
self.active_experiment.end(recorder_status)
|
||||
self.active_experiment = None
|
||||
# When an experiment end, we will release the current uri.
|
||||
self._current_uri = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
def create_exp(self, experiment_name: Optional[Text] = None):
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
|
||||
@@ -34,7 +34,7 @@ class Recorder:
|
||||
self.status = Recorder.STATUS_S
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.info)
|
||||
return "{name}(info={info})".format(name=self.__class__.__name__, info=self.info)
|
||||
|
||||
def __str__(self):
|
||||
return str(self.info)
|
||||
@@ -201,7 +201,7 @@ class MLflowRecorder(Recorder):
|
||||
def __init__(self, experiment_id, uri, name=None, mlflow_run=None):
|
||||
super(MLflowRecorder, self).__init__(experiment_id, name)
|
||||
self._uri = uri
|
||||
self.artifact_uri = None
|
||||
self._artifact_uri = None
|
||||
self.client = mlflow.tracking.MlflowClient(tracking_uri=self._uri)
|
||||
# construct from mlflow run
|
||||
if mlflow_run is not None:
|
||||
@@ -220,14 +220,51 @@ class MLflowRecorder(Recorder):
|
||||
else None
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
name = self.__class__.__name__
|
||||
space_length = len(name) + 1
|
||||
return "{name}(info={info},\n{space}uri={uri},\n{space}artifact_uri={artifact_uri},\n{space}client={client})".format(
|
||||
name=name,
|
||||
space=" " * space_length,
|
||||
info=self.info,
|
||||
uri=self.uri,
|
||||
artifact_uri=self.artifact_uri,
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
@property
|
||||
def uri(self):
|
||||
return self._uri
|
||||
|
||||
@property
|
||||
def artifact_uri(self):
|
||||
return self._artifact_uri
|
||||
|
||||
def get_local_dir(self):
|
||||
"""
|
||||
This function will return the directory path of this recorder.
|
||||
"""
|
||||
if self.artifact_uri is not None:
|
||||
local_dir_path = Path(self.artifact_uri.lstrip("file:")) / ".."
|
||||
local_dir_path = str(local_dir_path.resolve())
|
||||
if os.path.isdir(local_dir_path):
|
||||
return local_dir_path
|
||||
else:
|
||||
raise RuntimeError("This recorder is not saved in the local file system.")
|
||||
|
||||
else:
|
||||
raise Exception(
|
||||
"Please make sure the recorder has been created and started properly before getting artifact uri."
|
||||
)
|
||||
|
||||
def start_run(self):
|
||||
# set the tracking uri
|
||||
mlflow.set_tracking_uri(self._uri)
|
||||
mlflow.set_tracking_uri(self.uri)
|
||||
# start the run
|
||||
run = mlflow.start_run(self.id, self.experiment_id, self.name)
|
||||
# save the run id and artifact_uri
|
||||
self.id = run.info.run_id
|
||||
self.artifact_uri = run.info.artifact_uri
|
||||
self._artifact_uri = run.info.artifact_uri
|
||||
self.start_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
self.status = Recorder.STATUS_R
|
||||
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
|
||||
@@ -247,7 +284,7 @@ class MLflowRecorder(Recorder):
|
||||
self.status = status
|
||||
|
||||
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
if local_path is not None:
|
||||
self.client.log_artifacts(self.id, local_path, artifact_path)
|
||||
else:
|
||||
@@ -259,7 +296,7 @@ class MLflowRecorder(Recorder):
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
def load_object(self, name):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
path = self.client.download_artifacts(self.id, name)
|
||||
with Path(path).open("rb") as f:
|
||||
return pickle.load(f)
|
||||
@@ -289,7 +326,7 @@ class MLflowRecorder(Recorder):
|
||||
)
|
||||
|
||||
def list_artifacts(self, artifact_path=None):
|
||||
assert self._uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
|
||||
artifacts = self.client.list_artifacts(self.id, artifact_path)
|
||||
return [art.path for art in artifacts]
|
||||
|
||||
|
||||
430
scripts/data_collector/base.py
Normal file
430
scripts/data_collector/base.py
Normal file
@@ -0,0 +1,430 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import abc
|
||||
import time
|
||||
import datetime
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from typing import Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from qlib.utils import code_to_fname
|
||||
|
||||
|
||||
class BaseCollector(abc.ABC):
|
||||
|
||||
CACHE_FLAG = "CACHED"
|
||||
NORMAL_FLAG = "NORMAL"
|
||||
|
||||
DEFAULT_START_DATETIME_1D = pd.Timestamp("2000-01-01")
|
||||
DEFAULT_START_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
DEFAULT_END_DATETIME_1D = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
DEFAULT_END_DATETIME_1MIN = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
max_workers=4,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
save_dir: str
|
||||
stock save dir
|
||||
max_workers: int
|
||||
workers, default 4
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self.delay = delay
|
||||
self.max_workers = max_workers
|
||||
self.max_collector_count = max_collector_count
|
||||
self.mini_symbol_map = {}
|
||||
self.interval = interval
|
||||
self.check_small_data = check_data_length
|
||||
|
||||
self.start_datetime = self.normalize_start_datetime(start)
|
||||
self.end_datetime = self.normalize_end_datetime(end)
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
def normalize_start_datetime(self, start_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(start_datetime))
|
||||
if start_datetime
|
||||
else getattr(self, f"DEFAULT_START_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
def normalize_end_datetime(self, end_datetime: [str, pd.Timestamp] = None):
|
||||
return (
|
||||
pd.Timestamp(str(end_datetime))
|
||||
if end_datetime
|
||||
else getattr(self, f"DEFAULT_END_DATETIME_{self.interval.upper()}")
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
"""get data with symbol
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
interval: str
|
||||
value from [1min, 1d]
|
||||
start_datetime: pd.Timestamp
|
||||
end_datetime: pd.Timestamp
|
||||
|
||||
Returns
|
||||
---------
|
||||
pd.DataFrame, "symbol" in pd.columns
|
||||
|
||||
"""
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def sleep(self):
|
||||
time.sleep(self.delay)
|
||||
|
||||
def _simple_collector(self, symbol: str):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
|
||||
"""
|
||||
self.sleep()
|
||||
df = self.get_data(symbol, self.interval, self.start_datetime, self.end_datetime)
|
||||
_result = self.NORMAL_FLAG
|
||||
if self.check_small_data:
|
||||
_result = self.cache_small_data(symbol, df)
|
||||
if _result == self.NORMAL_FLAG:
|
||||
self.save_instrument(symbol, df)
|
||||
return _result
|
||||
|
||||
def save_instrument(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
|
||||
def cache_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self.mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return self.CACHE_FLAG
|
||||
else:
|
||||
if symbol in self.mini_symbol_map:
|
||||
self.mini_symbol_map.pop(symbol)
|
||||
return self.NORMAL_FLAG
|
||||
|
||||
def _collector(self, stock_list):
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._simple_collector, stock_list)):
|
||||
if _result != self.NORMAL_FLAG:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self.mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self.max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self.mini_symbol_map.items():
|
||||
self.save_instrument(
|
||||
_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"])
|
||||
)
|
||||
if self.mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self.mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
|
||||
class BaseNormalize(abc.ABC):
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# normalize
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[BaseNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class BaseRun(abc.ABC):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d"):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str
|
||||
The directory where the raw data collected from the Internet is saved, default "Path(__file__).parent/source"
|
||||
normalize_dir: str
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = Path(self.default_base_dir).joinpath("_source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = Path(self.default_base_dir).joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
self.interval = interval
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def collector_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def normalize_class_name(self):
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
delay=0,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
Parameters
|
||||
----------
|
||||
max_collector_count: int
|
||||
default 2
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
start: str
|
||||
start datetime, default "2000-01-01"
|
||||
end: str
|
||||
end datetime, default ``pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))``
|
||||
check_data_length: bool
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
|
||||
Examples
|
||||
---------
|
||||
# get daily data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1d
|
||||
# get 1m data
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(self._cur_module, self.collector_class_name) # type: Type[BaseCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
).collector_data()
|
||||
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
symbol field name, default symbol
|
||||
|
||||
Examples
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, self.normalize_class_name)
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
@@ -10,10 +10,10 @@ pip install -r requirements.txt
|
||||
|
||||
```bash
|
||||
# parse instruments, using in qlib/instruments.
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method parse_instruments
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method parse_instruments
|
||||
|
||||
# parse new companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/cn_data --method save_new_companies
|
||||
python collector.py --index_name SP500 --qlib_dir ~/.qlib/qlib_data/us_data --method save_new_companies
|
||||
|
||||
# index_name support: SP500, NASDAQ100, DJIA, SP400
|
||||
# help
|
||||
|
||||
@@ -48,7 +48,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1d -
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="CN")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1d", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
```
|
||||
|
||||
@@ -78,7 +78,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_cn_1min
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="CN")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_cn_1min", region="cn")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="1min")
|
||||
|
||||
```
|
||||
@@ -97,7 +97,7 @@ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source/us_1d
|
||||
|
||||
# dump data
|
||||
cd qlib/scripts
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/cn_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
python dump_bin.py dump_all --csv_path ~/.qlib/stock_data/source/us_1d_nor --qlib_dir ~/.qlib/stock_data/source/qlib_us_1d --freq day --exclude_fields date,adjclose,dividends,splits,symbol
|
||||
```
|
||||
|
||||
#### 1d from qlib
|
||||
@@ -113,7 +113,7 @@ python scripts/get_data.py qlib_data --target_dir ~/.qlib/qlib_data/qlib_us_1d -
|
||||
import qlib
|
||||
from qlib.data import D
|
||||
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="US")
|
||||
qlib.init(provider_uri="~/.qlib/qlib_data/qlib_us_1d", region="us")
|
||||
df = D.features(D.instruments("all"), ["$close"], freq="day")
|
||||
|
||||
```
|
||||
|
||||
@@ -10,158 +10,26 @@ import importlib
|
||||
from abc import ABC
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Type
|
||||
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
|
||||
|
||||
import fire
|
||||
import requests
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
from loguru import logger
|
||||
from yahooquery import Ticker
|
||||
from dateutil.tz import tzlocal
|
||||
from qlib.utils import code_to_fname, fname_to_code
|
||||
from qlib.config import REG_CN as REGION_CN
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
REGION_CN = "CN"
|
||||
REGION_US = "US"
|
||||
|
||||
|
||||
class YahooData:
|
||||
START_DATETIME = pd.Timestamp("2000-01-01")
|
||||
HIGH_FREQ_START_DATETIME = pd.Timestamp(datetime.datetime.now() - pd.Timedelta(days=5 * 6))
|
||||
END_DATETIME = pd.Timestamp(datetime.datetime.now() + pd.Timedelta(days=1))
|
||||
INTERVAL_1min = "1min"
|
||||
INTERVAL_1d = "1d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
timezone: str = None,
|
||||
start=None,
|
||||
end=None,
|
||||
interval="1d",
|
||||
delay=0,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
timezone: str
|
||||
The timezone where the data is located
|
||||
delay: float
|
||||
time.sleep(delay), default 0
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1min
|
||||
start: str
|
||||
start datetime, default None
|
||||
end: str
|
||||
end datetime, default None
|
||||
show_1min_logging: bool
|
||||
show 1min logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self._timezone = tzlocal() if timezone is None else timezone
|
||||
self._delay = delay
|
||||
self._interval = interval
|
||||
self._show_1min_logging = show_1min_logging
|
||||
self.start_datetime = pd.Timestamp(str(start)) if start else self.START_DATETIME
|
||||
self.end_datetime = min(pd.Timestamp(str(end)) if end else self.END_DATETIME, self.END_DATETIME)
|
||||
if self._interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.HIGH_FREQ_START_DATETIME)
|
||||
elif self._interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self._interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
def _sleep(self):
|
||||
time.sleep(self._delay)
|
||||
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
def _show_logging_func():
|
||||
if interval == YahooData.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def get_data(self, symbol: str) -> [pd.DataFrame]:
|
||||
def _get_simple(start_, end_):
|
||||
self._sleep()
|
||||
_remote_interval = "1m" if self._interval == self.INTERVAL_1min else self._interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
show_1min_logging=self._show_1min_logging,
|
||||
)
|
||||
|
||||
_result = None
|
||||
if self._interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
elif self._interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(self.start_datetime, self.end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self._interval}")
|
||||
return _result
|
||||
|
||||
|
||||
class YahooCollector:
|
||||
class YahooCollector(BaseCollector):
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: [str, Path],
|
||||
@@ -173,7 +41,6 @@ class YahooCollector:
|
||||
delay=0,
|
||||
check_data_length: bool = False,
|
||||
limit_nums: int = None,
|
||||
show_1min_logging: bool = False,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -197,131 +64,118 @@ class YahooCollector:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
"""
|
||||
self.save_dir = Path(save_dir).expanduser().resolve()
|
||||
self.save_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._delay = delay
|
||||
self.max_workers = max_workers
|
||||
self._max_collector_count = max_collector_count
|
||||
self._mini_symbol_map = {}
|
||||
self._interval = interval
|
||||
self._check_small_data = check_data_length
|
||||
|
||||
self.stock_list = sorted(set(self.get_stock_list()))
|
||||
if limit_nums is not None:
|
||||
try:
|
||||
self.stock_list = self.stock_list[: int(limit_nums)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Cannot use limit_nums={limit_nums}, the parameter will be ignored")
|
||||
|
||||
self.yahoo_data = YahooData(
|
||||
timezone=self._timezone,
|
||||
super(YahooCollector, self).__init__(
|
||||
save_dir=save_dir,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
max_workers=max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
show_1min_logging=show_1min_logging,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def min_numbers_trading(self):
|
||||
# daily, one year: 252 / 4
|
||||
# us 1min, a week: 6.5 * 60 * 5
|
||||
# cn 1min, a week: 4 * 60 * 5
|
||||
raise NotImplementedError("rewrite min_numbers_trading")
|
||||
self.init_datetime()
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_stock_list(self):
|
||||
raise NotImplementedError("rewrite get_stock_list")
|
||||
def init_datetime(self):
|
||||
if self.interval == self.INTERVAL_1min:
|
||||
self.start_datetime = max(self.start_datetime, self.DEFAULT_START_DATETIME_1MIN)
|
||||
elif self.interval == self.INTERVAL_1d:
|
||||
pass
|
||||
else:
|
||||
raise ValueError(f"interval error: {self.interval}")
|
||||
|
||||
# using for 1min
|
||||
self._next_datetime = self.convert_datetime(self.start_datetime.date() + pd.Timedelta(days=1), self._timezone)
|
||||
self._latest_datetime = self.convert_datetime(self.end_datetime.date(), self._timezone)
|
||||
|
||||
self.start_datetime = self.convert_datetime(self.start_datetime, self._timezone)
|
||||
self.end_datetime = self.convert_datetime(self.end_datetime, self._timezone)
|
||||
|
||||
@staticmethod
|
||||
def convert_datetime(dt: [pd.Timestamp, datetime.date, str], timezone):
|
||||
try:
|
||||
dt = pd.Timestamp(dt, tz=timezone).timestamp()
|
||||
dt = pd.Timestamp(dt, tz=tzlocal(), unit="s")
|
||||
except ValueError as e:
|
||||
pass
|
||||
return dt
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def _timezone(self):
|
||||
raise NotImplementedError("rewrite get_timezone")
|
||||
|
||||
def save_stock(self, symbol, df: pd.DataFrame):
|
||||
"""save stock data to file
|
||||
@staticmethod
|
||||
def get_data_from_remote(symbol, interval, start, end, show_1min_logging: bool = False):
|
||||
error_msg = f"{symbol}-{interval}-{start}-{end}"
|
||||
|
||||
Parameters
|
||||
----------
|
||||
symbol: str
|
||||
stock code
|
||||
df : pd.DataFrame
|
||||
df.columns must contain "symbol" and "datetime"
|
||||
"""
|
||||
if df.empty:
|
||||
logger.warning(f"{symbol} is empty")
|
||||
return
|
||||
def _show_logging_func():
|
||||
if interval == YahooCollector.INTERVAL_1min and show_1min_logging:
|
||||
logger.warning(f"{error_msg}:{_resp}")
|
||||
|
||||
symbol = self.normalize_symbol(symbol)
|
||||
symbol = code_to_fname(symbol)
|
||||
stock_path = self.save_dir.joinpath(f"{symbol}.csv")
|
||||
df["symbol"] = symbol
|
||||
if stock_path.exists():
|
||||
_old_df = pd.read_csv(stock_path)
|
||||
df = _old_df.append(df, sort=False)
|
||||
df.to_csv(stock_path, index=False)
|
||||
interval = "1m" if interval in ["1m", "1min"] else interval
|
||||
try:
|
||||
_resp = Ticker(symbol, asynchronous=False).history(interval=interval, start=start, end=end)
|
||||
if isinstance(_resp, pd.DataFrame):
|
||||
return _resp.reset_index()
|
||||
elif isinstance(_resp, dict):
|
||||
_temp_data = _resp.get(symbol, {})
|
||||
if isinstance(_temp_data, str) or (
|
||||
isinstance(_resp, dict) and _temp_data.get("indicators", {}).get("quote", None) is None
|
||||
):
|
||||
_show_logging_func()
|
||||
else:
|
||||
_show_logging_func()
|
||||
except Exception as e:
|
||||
logger.warning(f"{error_msg}:{e}")
|
||||
|
||||
def _save_small_data(self, symbol, df):
|
||||
if len(df) <= self.min_numbers_trading:
|
||||
logger.warning(f"the number of trading days of {symbol} is less than {self.min_numbers_trading}!")
|
||||
_temp = self._mini_symbol_map.setdefault(symbol, [])
|
||||
_temp.append(df.copy())
|
||||
return None
|
||||
else:
|
||||
if symbol in self._mini_symbol_map:
|
||||
self._mini_symbol_map.pop(symbol)
|
||||
return symbol
|
||||
def get_data(
|
||||
self, symbol: str, interval: str, start_datetime: pd.Timestamp, end_datetime: pd.Timestamp
|
||||
) -> pd.DataFrame:
|
||||
def _get_simple(start_, end_):
|
||||
self.sleep()
|
||||
_remote_interval = "1m" if interval == self.INTERVAL_1min else interval
|
||||
return self.get_data_from_remote(
|
||||
symbol,
|
||||
interval=_remote_interval,
|
||||
start=start_,
|
||||
end=end_,
|
||||
)
|
||||
|
||||
def _get_data(self, symbol):
|
||||
_result = None
|
||||
df = self.yahoo_data.get_data(symbol)
|
||||
if isinstance(df, pd.DataFrame):
|
||||
if not df.empty:
|
||||
if self._check_small_data:
|
||||
if self._save_small_data(symbol, df) is not None:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
else:
|
||||
_result = symbol
|
||||
self.save_stock(symbol, df)
|
||||
return _result
|
||||
if interval == self.INTERVAL_1d:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
elif interval == self.INTERVAL_1min:
|
||||
if self._next_datetime >= self._latest_datetime:
|
||||
_result = _get_simple(start_datetime, end_datetime)
|
||||
else:
|
||||
_res = []
|
||||
|
||||
def _collector(self, stock_list):
|
||||
def _get_multi(start_, end_):
|
||||
_resp = _get_simple(start_, end_)
|
||||
if _resp is not None and not _resp.empty:
|
||||
_res.append(_resp)
|
||||
|
||||
error_symbol = []
|
||||
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
|
||||
with tqdm(total=len(stock_list)) as p_bar:
|
||||
for _symbol, _result in zip(stock_list, executor.map(self._get_data, stock_list)):
|
||||
if _result is None:
|
||||
error_symbol.append(_symbol)
|
||||
p_bar.update()
|
||||
print(error_symbol)
|
||||
logger.info(f"error symbol nums: {len(error_symbol)}")
|
||||
logger.info(f"current get symbol nums: {len(stock_list)}")
|
||||
error_symbol.extend(self._mini_symbol_map.keys())
|
||||
return sorted(set(error_symbol))
|
||||
for _s, _e in (
|
||||
(self.start_datetime, self._next_datetime),
|
||||
(self._latest_datetime, self.end_datetime),
|
||||
):
|
||||
_get_multi(_s, _e)
|
||||
for _start in pd.date_range(self._next_datetime, self._latest_datetime, closed="left"):
|
||||
_end = _start + pd.Timedelta(days=1)
|
||||
_get_multi(_start, _end)
|
||||
if _res:
|
||||
_result = pd.concat(_res, sort=False).sort_values(["symbol", "date"])
|
||||
else:
|
||||
raise ValueError(f"cannot support {self.interval}")
|
||||
return pd.DataFrame() if _result is None else _result
|
||||
|
||||
def collector_data(self):
|
||||
"""collector data"""
|
||||
logger.info("start collector yahoo data......")
|
||||
stock_list = self.stock_list
|
||||
for i in range(self._max_collector_count):
|
||||
if not stock_list:
|
||||
break
|
||||
logger.info(f"getting data: {i+1}")
|
||||
stock_list = self._collector(stock_list)
|
||||
logger.info(f"{i+1} finish.")
|
||||
for _symbol, _df_list in self._mini_symbol_map.items():
|
||||
self.save_stock(_symbol, pd.concat(_df_list, sort=False).drop_duplicates(["date"]).sort_values(["date"]))
|
||||
if self._mini_symbol_map:
|
||||
logger.warning(f"less than {self.min_numbers_trading} stock list: {list(self._mini_symbol_map.keys())}")
|
||||
logger.info(f"total {len(self.stock_list)}, error: {len(set(stock_list))}")
|
||||
|
||||
super(YahooCollector, self).collector_data()
|
||||
self.download_index_data()
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -329,11 +183,6 @@ class YahooCollector:
|
||||
"""download index data"""
|
||||
raise NotImplementedError("rewrite download_index_data")
|
||||
|
||||
@abc.abstractmethod
|
||||
def normalize_symbol(self, symbol: str):
|
||||
"""normalize symbol"""
|
||||
raise NotImplementedError("rewrite normalize_symbol")
|
||||
|
||||
|
||||
class YahooCollectorCN(YahooCollector, ABC):
|
||||
def get_stock_list(self):
|
||||
@@ -360,8 +209,8 @@ class YahooCollectorCN1d(YahooCollectorCN):
|
||||
def download_index_data(self):
|
||||
# TODO: from MSN
|
||||
_format = "%Y%m%d"
|
||||
_begin = self.yahoo_data.start_datetime.strftime(_format)
|
||||
_end = (self.yahoo_data.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
_begin = self.start_datetime.strftime(_format)
|
||||
_end = (self.end_datetime + pd.Timedelta(days=-1)).strftime(_format)
|
||||
for _index_name, _index_code in {"csi300": "000300", "csi100": "000903"}.items():
|
||||
logger.info(f"get bench data: {_index_name}({_index_code})......")
|
||||
try:
|
||||
@@ -396,7 +245,7 @@ class YahooCollectorCN1min(YahooCollectorCN):
|
||||
|
||||
def download_index_data(self):
|
||||
# TODO: 1m
|
||||
logger.warning(f"{self.__class__.__name__} {self._interval} does not support: download_index_data")
|
||||
logger.warning(f"{self.__class__.__name__} {self.interval} does not support: download_index_data")
|
||||
|
||||
|
||||
class YahooCollectorUS(YahooCollector, ABC):
|
||||
@@ -433,29 +282,10 @@ class YahooCollectorUS1min(YahooCollectorUS):
|
||||
return 60 * 6.5 * 5
|
||||
|
||||
|
||||
class YahooNormalize:
|
||||
class YahooNormalize(BaseNormalize):
|
||||
COLUMNS = ["open", "close", "high", "low", "volume"]
|
||||
DAILY_FORMAT = "%Y-%m-%d"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
self._date_field_name = date_field_name
|
||||
self._symbol_field_name = symbol_field_name
|
||||
|
||||
self._calendar_list = self._get_calendar_list()
|
||||
|
||||
@staticmethod
|
||||
def normalize_yahoo(
|
||||
df: pd.DataFrame,
|
||||
@@ -498,11 +328,6 @@ class YahooNormalize:
|
||||
df = self.adjusted_price(df)
|
||||
return df
|
||||
|
||||
@abc.abstractmethod
|
||||
def _get_calendar_list(self):
|
||||
"""Get benchmark calendar"""
|
||||
raise NotImplementedError("")
|
||||
|
||||
@abc.abstractmethod
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
"""adjusted price"""
|
||||
@@ -618,7 +443,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
# get 1d data from yahoo
|
||||
_start = pd.Timestamp(df[self._date_field_name].min()).strftime(self.DAILY_FORMAT)
|
||||
_end = (pd.Timestamp(df[self._date_field_name].max()) + pd.Timedelta(days=1)).strftime(self.DAILY_FORMAT)
|
||||
data_1d = YahooData.get_data_from_remote(self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end)
|
||||
data_1d = YahooCollector.get_data_from_remote(
|
||||
self.symbol_to_yahoo(symbol), interval="1d", start=_start, end=_end
|
||||
)
|
||||
if data_1d is None or data_1d.empty:
|
||||
df["factor"] = 1
|
||||
# TODO: np.nan or 1 or 0
|
||||
@@ -723,62 +550,8 @@ class YahooNormalizeCN1min(YahooNormalizeCN, YahooNormalize1min):
|
||||
return get_calendar_list("ALL")
|
||||
|
||||
|
||||
class Normalize:
|
||||
def __init__(
|
||||
self,
|
||||
source_dir: [str, Path],
|
||||
target_dir: [str, Path],
|
||||
normalize_class: Type[YahooNormalize],
|
||||
max_workers: int = 16,
|
||||
date_field_name: str = "date",
|
||||
symbol_field_name: str = "symbol",
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
source_dir: str or Path
|
||||
The directory where the raw data collected from the Internet is saved
|
||||
target_dir: str or Path
|
||||
Directory for normalize data
|
||||
normalize_class: Type[YahooNormalize]
|
||||
normalize class
|
||||
max_workers: int
|
||||
Concurrent number, default is 16
|
||||
date_field_name: str
|
||||
date field name, default is date
|
||||
symbol_field_name: str
|
||||
symbol field name, default is symbol
|
||||
"""
|
||||
if not (source_dir and target_dir):
|
||||
raise ValueError("source_dir and target_dir cannot be None")
|
||||
self._source_dir = Path(source_dir).expanduser()
|
||||
self._target_dir = Path(target_dir).expanduser()
|
||||
self._target_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._max_workers = max_workers
|
||||
|
||||
self._normalize_obj = normalize_class(date_field_name=date_field_name, symbol_field_name=symbol_field_name)
|
||||
|
||||
def _executor(self, file_path: Path):
|
||||
file_path = Path(file_path)
|
||||
df = pd.read_csv(file_path)
|
||||
df = self._normalize_obj.normalize(df)
|
||||
if not df.empty:
|
||||
df.to_csv(self._target_dir.joinpath(file_path.name), index=False)
|
||||
|
||||
def normalize(self):
|
||||
logger.info("normalize data......")
|
||||
|
||||
with ProcessPoolExecutor(max_workers=self._max_workers) as worker:
|
||||
file_list = list(self._source_dir.glob("*.csv"))
|
||||
with tqdm(total=len(file_list)) as p_bar:
|
||||
for _ in worker.map(self._executor, file_list):
|
||||
p_bar.update()
|
||||
|
||||
|
||||
class Run:
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, region=REGION_CN):
|
||||
class Run(BaseRun):
|
||||
def __init__(self, source_dir=None, normalize_dir=None, max_workers=4, interval="1d", region=REGION_CN):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
@@ -789,23 +562,26 @@ class Run:
|
||||
Directory for normalize data, default "Path(__file__).parent/normalize"
|
||||
max_workers: int
|
||||
Concurrent number, default is 4
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
region: str
|
||||
region, value from ["CN", "US"], default "CN"
|
||||
"""
|
||||
if source_dir is None:
|
||||
source_dir = CUR_DIR.joinpath("source")
|
||||
self.source_dir = Path(source_dir).expanduser().resolve()
|
||||
self.source_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if normalize_dir is None:
|
||||
normalize_dir = CUR_DIR.joinpath("normalize")
|
||||
self.normalize_dir = Path(normalize_dir).expanduser().resolve()
|
||||
self.normalize_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
self._cur_module = importlib.import_module("collector")
|
||||
self.max_workers = max_workers
|
||||
super().__init__(source_dir, normalize_dir, max_workers, interval)
|
||||
self.region = region
|
||||
|
||||
@property
|
||||
def collector_class_name(self):
|
||||
return f"YahooCollector{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def normalize_class_name(self):
|
||||
return f"YahooNormalize{self.region.upper()}{self.interval}"
|
||||
|
||||
@property
|
||||
def default_base_dir(self) -> [Path, str]:
|
||||
return CUR_DIR
|
||||
|
||||
def download_data(
|
||||
self,
|
||||
max_collector_count=2,
|
||||
@@ -815,7 +591,6 @@ class Run:
|
||||
interval="1d",
|
||||
check_data_length=False,
|
||||
limit_nums=None,
|
||||
show_1min_logging=False,
|
||||
):
|
||||
"""download data from Internet
|
||||
|
||||
@@ -835,8 +610,6 @@ class Run:
|
||||
check data length, by default False
|
||||
limit_nums: int
|
||||
using for debug, by default None
|
||||
show_1min_logging: bool
|
||||
show 1m logging, by default False; if True, there may be many warning logs
|
||||
|
||||
Examples
|
||||
---------
|
||||
@@ -846,29 +619,13 @@ class Run:
|
||||
$ python collector.py download_data --source_dir ~/.qlib/stock_data/source --region CN --start 2020-11-01 --end 2020-11-10 --delay 0.1 --interval 1m
|
||||
"""
|
||||
|
||||
_class = getattr(
|
||||
self._cur_module, f"YahooCollector{self.region.upper()}{interval}"
|
||||
) # type: Type[YahooCollector]
|
||||
_class(
|
||||
self.source_dir,
|
||||
max_workers=self.max_workers,
|
||||
max_collector_count=max_collector_count,
|
||||
delay=delay,
|
||||
start=start,
|
||||
end=end,
|
||||
interval=interval,
|
||||
check_data_length=check_data_length,
|
||||
limit_nums=limit_nums,
|
||||
show_1min_logging=show_1min_logging,
|
||||
).collector_data()
|
||||
super(Run, self).download_data(max_collector_count, delay, start, end, interval, check_data_length, limit_nums)
|
||||
|
||||
def normalize_data(self, interval: str = "1d", date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
def normalize_data(self, date_field_name: str = "date", symbol_field_name: str = "symbol"):
|
||||
"""normalize data
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interval: str
|
||||
freq, value from [1min, 1d], default 1d
|
||||
date_field_name: str
|
||||
date field name, default date
|
||||
symbol_field_name: str
|
||||
@@ -878,16 +635,7 @@ class Run:
|
||||
---------
|
||||
$ python collector.py normalize_data --source_dir ~/.qlib/stock_data/source --normalize_dir ~/.qlib/stock_data/normalize --region CN --interval 1d
|
||||
"""
|
||||
_class = getattr(self._cur_module, f"YahooNormalize{self.region.upper()}{interval}")
|
||||
yc = Normalize(
|
||||
source_dir=self.source_dir,
|
||||
target_dir=self.normalize_dir,
|
||||
normalize_class=_class,
|
||||
max_workers=self.max_workers,
|
||||
date_field_name=date_field_name,
|
||||
symbol_field_name=symbol_field_name,
|
||||
)
|
||||
yc.normalize()
|
||||
super(Run, self).normalize_data(date_field_name, symbol_field_name)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -96,7 +96,6 @@ port_analysis_config = {
|
||||
}
|
||||
|
||||
|
||||
# train
|
||||
def train():
|
||||
"""train model
|
||||
|
||||
@@ -111,6 +110,9 @@ def train():
|
||||
# model initiaiton
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
# To test __repr__
|
||||
print(dataset)
|
||||
print(R)
|
||||
|
||||
# start exp
|
||||
with R.start(experiment_name="workflow"):
|
||||
@@ -119,6 +121,10 @@ def train():
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
print(recorder)
|
||||
# To test get_local_dir
|
||||
print(recorder.get_local_dir())
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
@@ -133,6 +139,27 @@ def train():
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
"""A fake experiment workflow to test uri
|
||||
|
||||
Returns
|
||||
-------
|
||||
pass_or_not_for_default_uri: bool
|
||||
pass_or_not_for_current_uri: bool
|
||||
temporary_exp_dir: str
|
||||
"""
|
||||
|
||||
# start exp
|
||||
default_uri = R.get_uri()
|
||||
current_uri = "file:./temp-test-exp-mag"
|
||||
with R.start(experiment_name="fake_workflow_for_expm", uri=current_uri):
|
||||
R.log_params(**flatten_dict(task))
|
||||
|
||||
current_uri_to_check = R.get_uri()
|
||||
default_uri_to_check = R.get_uri()
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
"""backtest and analysis
|
||||
|
||||
@@ -181,6 +208,12 @@ class TestAllFlow(TestAutoData):
|
||||
"backtest failed",
|
||||
)
|
||||
|
||||
def test_2_expmanager(self):
|
||||
pass_default, pass_current, uri_path = fake_experiment()
|
||||
self.assertTrue(pass_default, msg="default uri is incorrect")
|
||||
self.assertTrue(pass_current, msg="current uri is incorrect")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
|
||||
def suite():
|
||||
_suite = unittest.TestSuite()
|
||||
|
||||
Reference in New Issue
Block a user