mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge remote-tracking branch 'qlib/main' into qlib_main
# Conflicts: # scripts/data_collector/yahoo/README.md
This commit is contained in:
@@ -68,7 +68,7 @@ Your feedbacks about the features are very important.
|
||||
# Framework of Qlib
|
||||
|
||||
<div style="align: center">
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.1" />
|
||||
<img src="http://fintech.msra.cn/images_v060/framework.png?v=0.2" />
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
BIN
docs/_static/img/framework.png
vendored
BIN
docs/_static/img/framework.png
vendored
Binary file not shown.
|
Before Width: | Height: | Size: 271 KiB After Width: | Height: | Size: 208 KiB |
@@ -100,12 +100,19 @@ Converting CSV Format into Qlib Format
|
||||
|
||||
``Qlib`` has provided the script ``scripts/dump_bin.py`` to convert **any** data in CSV format into `.bin` files (``Qlib`` format) as long as they are in the correct format.
|
||||
|
||||
Users can download the demo china-stock data in CSV format as follows for reference to the CSV format.
|
||||
Besides downloading the prepared demo data, users could download demo data directly from the Collector as follows for reference to the CSV format.
|
||||
Here are some example:
|
||||
|
||||
.. code-block:: bash
|
||||
for daily data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/get_data.py csv_data_cn --target_dir ~/.qlib/csv_data/cn_data
|
||||
|
||||
for 1min data:
|
||||
.. code-block:: bash
|
||||
|
||||
python scripts/data_collector/yahoo/collector.py download_data --source_dir ~/.qlib/stock_data/source/cn_1min --region CN --start 2021-05-20 --end 2021-05-23 --delay 0.1 --interval 1min --limit_nums 10
|
||||
|
||||
Users can also provide their own data in CSV format. However, the CSV data **must satisfies** following criterions:
|
||||
|
||||
- CSV file is named after a specific stock *or* the CSV file includes a column of the stock name
|
||||
@@ -173,6 +180,16 @@ After conversion, users can find their Qlib format data in the directory `~/.qli
|
||||
|
||||
In the convention of `Qlib` data processing, `open, close, high, low, volume, money and factor` will be set to NaN if the stock is suspended.
|
||||
|
||||
Stock Pool (Market)
|
||||
--------------------------------
|
||||
|
||||
``Qlib`` defines `stock pool <https://github.com/microsoft/qlib/blob/main/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml#L4>`_ as stock list and their date ranges. Predefined stock pools (e.g. csi300) may be imported as follows.
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
python collector.py --index_name CSI300 --qlib_dir <user qlib data dir> --method parse_instruments
|
||||
|
||||
|
||||
Multiple Stock Modes
|
||||
--------------------------------
|
||||
|
||||
|
||||
@@ -101,7 +101,7 @@ Graphical Result
|
||||
- Axis Y:
|
||||
- `ic`
|
||||
The `Pearson correlation coefficient` series between `label` and `prediction score`.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Featrue <data.html#feature>`_ for more details.
|
||||
In the above example, the `label` is formulated as `Ref($close, -1)/$close - 1`. Please refer to `Data Feature <data.html#feature>`_ for more details.
|
||||
|
||||
- `rank_ic`
|
||||
The `Spearman's rank correlation coefficient` series between `label` and `prediction score`.
|
||||
|
||||
@@ -111,8 +111,6 @@ Usage & Example
|
||||
pred_score, strategy=strategy, **BACKTEST_CONFIG
|
||||
)
|
||||
|
||||
Also, the above example has been given in ``examples/train_backtest_analyze.ipynb``.
|
||||
|
||||
To know more about the `prediction score` `pred_score` output by ``Forecast Model``, please refer to `Forecast Model: Model Training & Prediction <model.html>`_.
|
||||
|
||||
To know more about ``Intraday Trading``, please refer to `Intraday Trading: Model&Strategy Testing <backtest.html>`_.
|
||||
|
||||
@@ -82,7 +82,7 @@ The Custom models need to inherit `qlib.model.base.Model <../reference/api.html#
|
||||
return pd.Series(self.model.predict(x_test.values), index=x_test.index)
|
||||
|
||||
- Override the `finetune` method (Optional)
|
||||
- This method is optional to the users, and when users one to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- This method is optional to the users. When users want to use this method on their own models, they should inherit the ``ModelFT`` base class, which includes the interface of `finetune`.
|
||||
- The parameters must include the parameter `dataset`.
|
||||
- Code Example: In the following example, users will use `LightGBM` as the model and finetune it.
|
||||
.. code-block:: Python
|
||||
|
||||
52
examples/benchmarks/TCTS/README.md
Normal file
52
examples/benchmarks/TCTS/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
52
examples/benchmarks/TCTS/TCTS.md
Normal file
52
examples/benchmarks/TCTS/TCTS.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -1) - 1",
|
||||
"Ref($close, -4) / Ref($close, -1) - 1",
|
||||
"Ref($close, -5) / Ref($close, -1) - 1",
|
||||
"Ref($close, -6) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 5
|
||||
fore_lr: 5e-7
|
||||
weight_lr: 5e-7
|
||||
steps: 3
|
||||
target_label: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,7 +14,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -20,11 +20,17 @@ def init(default_conf="client", **kwargs):
|
||||
from .config import C
|
||||
from .data.cache import H
|
||||
|
||||
H.clear()
|
||||
|
||||
# FIXME: this logger ignored the level in config
|
||||
logger = get_module_logger("Initialization", level=logging.INFO)
|
||||
|
||||
skip_if_reg = kwargs.pop("skip_if_reg", False)
|
||||
if skip_if_reg and C.registered:
|
||||
# if we reinitialize Qlib during running an experiment `R.start`.
|
||||
# it will result in loss of the recorder
|
||||
logger.warning("Skip initialization because `skip_if_reg is True`")
|
||||
return
|
||||
|
||||
H.clear()
|
||||
C.set(default_conf, **kwargs)
|
||||
|
||||
# check path if server/local
|
||||
@@ -197,14 +203,15 @@ def auto_init(**kwargs):
|
||||
- Find the project configuration and init qlib
|
||||
- The parsing process will be affected by the `conf_type` of the configuration file
|
||||
- Init qlib with default config
|
||||
- Skip initialization if already initialized
|
||||
"""
|
||||
kwargs["skip_if_reg"] = kwargs.get("skip_if_reg", True)
|
||||
|
||||
try:
|
||||
pp = get_project_path(cur_path=kwargs.pop("cur_path", None))
|
||||
except FileNotFoundError:
|
||||
init(**kwargs)
|
||||
else:
|
||||
|
||||
conf_pp = pp / "config.yaml"
|
||||
with conf_pp.open() as f:
|
||||
conf = yaml.safe_load(f)
|
||||
|
||||
393
qlib/contrib/model/pytorch_tcts.py
Normal file
393
qlib/contrib/model/pytorch_tcts.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TCTS(Model):
|
||||
"""TCTS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluate metric used in early stop
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : str
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
fore_optimizer="adam",
|
||||
weight_optimizer="adam",
|
||||
output_dim=5,
|
||||
fore_lr=5e-7,
|
||||
weight_lr=5e-7,
|
||||
steps=3,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
target_label=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
self.logger.info("TCTS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.output_dim = output_dim
|
||||
self.fore_lr = fore_lr
|
||||
self.weight_lr = weight_lr
|
||||
self.steps = steps
|
||||
self.target_label = target_label
|
||||
|
||||
self.logger.info(
|
||||
"TCTS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
batch_size,
|
||||
early_stop,
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.fore_model = GRUModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 2 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
output_dim=self.output_dim,
|
||||
)
|
||||
if fore_optimizer.lower() == "adam":
|
||||
self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
elif fore_optimizer.lower() == "gd":
|
||||
self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(fore_optimizer))
|
||||
if weight_optimizer.lower() == "adam":
|
||||
self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
elif weight_optimizer.lower() == "gd":
|
||||
self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(weight_optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.fore_model.to(self.device)
|
||||
self.weight_model.to(self.device)
|
||||
|
||||
def loss_fn(self, pred, label, weight):
|
||||
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.init_fore_model = False
|
||||
|
||||
self.fore_model.train()
|
||||
self.weight_model.train()
|
||||
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = False
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
for i in range(self.steps):
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
init_pred = init_fore_model(feature)
|
||||
pred = self.fore_model(feature)
|
||||
|
||||
dis = init_pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
|
||||
loss = self.loss_fn(pred, label, weight) # hard
|
||||
|
||||
self.fore_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
|
||||
self.fore_optimizer.step()
|
||||
|
||||
x_valid_values = x_valid.values
|
||||
y_valid_values = np.squeeze(y_valid.values)
|
||||
|
||||
indices = np.arange(len(x_valid_values))
|
||||
np.random.shuffle(indices)
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# fix forecasting model and valid weight model
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
dis = pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
loc = torch.argmax(weight, 1)
|
||||
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
|
||||
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
|
||||
self.weight_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
|
||||
self.weight_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare training data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.fore_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
|
||||
if save_path == None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
best_loss = np.inf
|
||||
best_epoch = 0
|
||||
stop_round = 0
|
||||
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
|
||||
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
print("Epoch:", epoch)
|
||||
|
||||
print("training...")
|
||||
self.train_epoch(x_train, y_train, x_valid, y_valid)
|
||||
print("evaluating...")
|
||||
val_loss = self.test_epoch(x_valid, y_valid)
|
||||
test_loss = self.test_epoch(x_test, y_test)
|
||||
|
||||
print("valid %.6f, test %.6f" % (val_loss, test_loss))
|
||||
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
stop_round = 0
|
||||
best_epoch = epoch
|
||||
torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
|
||||
torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
|
||||
|
||||
else:
|
||||
stop_round += 1
|
||||
if stop_round >= self.early_stop:
|
||||
print("early stop")
|
||||
break
|
||||
|
||||
print("best loss:", best_loss, "@", best_epoch)
|
||||
best_param = torch.load(save_path + "_fore_model.bin")
|
||||
self.fore_model.load_state_dict(best_param)
|
||||
best_param = torch.load(save_path + "_weight_model.bin")
|
||||
self.weight_model.load_state_dict(best_param)
|
||||
self.fitted = True
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.fore_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.fore_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.fore_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
|
||||
|
||||
|
||||
class MLPModel(nn.Module):
|
||||
def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
for i in range(num_layers):
|
||||
if i > 0:
|
||||
self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
|
||||
self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
|
||||
self.mlp.add_module("relu_%d" % i, nn.ReLU())
|
||||
|
||||
self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
# feature
|
||||
# [N, F]
|
||||
out = self.mlp(x).squeeze()
|
||||
out = self.softmax(out)
|
||||
return out
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, F*T]
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc_out(out[:, -1, :]).squeeze()
|
||||
@@ -148,7 +148,6 @@ class WeightStrategyBase(BaseStrategy, AdjustTimer):
|
||||
pred score for this trade date, index is stock_id, contain 'score' column.
|
||||
current : Position()
|
||||
current position.
|
||||
trade_exchange : Exchange()
|
||||
trade_date : pd.Timestamp
|
||||
trade date.
|
||||
"""
|
||||
|
||||
@@ -237,7 +237,7 @@ class CacheUtils:
|
||||
lock.acquire()
|
||||
except redis_lock.AlreadyAcquired:
|
||||
raise QlibCacheException(
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
f"""It sees the key(lock:{repr(lock_name)[1:-1]}-wlock) of the redis lock has existed in your redis db now.
|
||||
You can use the following command to clear your redis keys and rerun your commands:
|
||||
$ redis-cli
|
||||
> select {C.redis_task_db}
|
||||
@@ -784,10 +784,10 @@ class DiskDatasetCache(DatasetCache):
|
||||
def build_index_from_data(data, start_index=0):
|
||||
if data.empty:
|
||||
return pd.DataFrame()
|
||||
line_data = data.iloc[:, 0].fillna(0).groupby("datetime").count()
|
||||
line_data = data.groupby("datetime").size()
|
||||
line_data.sort_index(inplace=True)
|
||||
index_end = line_data.cumsum()
|
||||
index_start = index_end.shift(1).fillna(0)
|
||||
index_start = index_end.shift(1, fill_value=0)
|
||||
|
||||
index_data = pd.DataFrame()
|
||||
index_data["start"] = index_start
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from ...utils.serial import Serializable
|
||||
from typing import Union, List, Tuple, Dict, Text, Optional
|
||||
from ...utils import init_instance_by_config, np_ffill
|
||||
from ...utils import init_instance_by_config, np_ffill, time_to_slc_point
|
||||
from ...log import get_module_logger
|
||||
from .handler import DataHandler, DataHandlerLP
|
||||
from copy import deepcopy
|
||||
@@ -243,6 +243,8 @@ class TSDataSampler:
|
||||
|
||||
It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series
|
||||
dataset based on tabular data.
|
||||
- On time step dimension, the smaller index indicates the historical data and the larger index indicates the future
|
||||
data.
|
||||
|
||||
If user have further requirements for processing data, user could process them based on `TSDataSampler` or create
|
||||
more powerful subclasses.
|
||||
@@ -309,11 +311,19 @@ class TSDataSampler:
|
||||
self.data_index = deepcopy(self.data.index)
|
||||
|
||||
if flt_data is not None:
|
||||
self.flt_data = np.array(flt_data.reindex(self.data_index)).reshape(-1)
|
||||
if isinstance(flt_data, pd.DataFrame):
|
||||
assert len(flt_data.columns) == 1
|
||||
flt_data = flt_data.iloc[:, 0]
|
||||
# NOTE: bool(np.nan) is True !!!!!!!!
|
||||
# make sure reindex comes first. Otherwise extra NaN may appear.
|
||||
flt_data = flt_data.reindex(self.data_index).fillna(False).astype(np.bool)
|
||||
self.flt_data = flt_data.values
|
||||
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
|
||||
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
|
||||
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end))
|
||||
self.start_idx, self.end_idx = self.data_index.slice_locs(
|
||||
start=time_to_slc_point(start), end=time_to_slc_point(end)
|
||||
)
|
||||
self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance
|
||||
|
||||
del self.data # save memory
|
||||
@@ -341,7 +351,7 @@ class TSDataSampler:
|
||||
setattr(self, k, v)
|
||||
|
||||
@staticmethod
|
||||
def build_index(data: pd.DataFrame) -> dict:
|
||||
def build_index(data: pd.DataFrame) -> Tuple[pd.DataFrame, dict]:
|
||||
"""
|
||||
The relation of the data
|
||||
|
||||
@@ -352,9 +362,15 @@ class TSDataSampler:
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict:
|
||||
{<index>: <prev_index or None>}
|
||||
# get the previous index of a line given index
|
||||
Tuple[pd.DataFrame, dict]:
|
||||
1) the first element: reshape the original index into a <datetime(row), instrument(column)> 2D dataframe
|
||||
instrument SH600000 SH600004 SH600006 SH600007 SH600008 SH600009 ...
|
||||
datetime
|
||||
2021-01-11 0 1 2 3 4 5 ...
|
||||
2021-01-12 4146 4147 4148 4149 4150 4151 ...
|
||||
2021-01-13 8293 8294 8295 8296 8297 8298 ...
|
||||
2021-01-14 12441 12442 12443 12444 12445 12446 ...
|
||||
2) the second element: {<original index>: <row, col>}
|
||||
"""
|
||||
# object incase of pandas converting int to flaot
|
||||
idx_df = pd.Series(range(data.shape[0]), index=data.index, dtype=object)
|
||||
@@ -491,7 +507,9 @@ class TSDatasetH(DatasetH):
|
||||
- The dimension of a batch of data <batch_idx, feature, timestep>
|
||||
"""
|
||||
|
||||
def __init__(self, step_len=30, **kwargs):
|
||||
DEFAULT_STEP_LEN = 30
|
||||
|
||||
def __init__(self, step_len=DEFAULT_STEP_LEN, **kwargs):
|
||||
self.step_len = step_len
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
@@ -28,16 +28,18 @@ class QlibLogger(metaclass=MetaLogger):
|
||||
|
||||
def __init__(self, module_name):
|
||||
self.module_name = module_name
|
||||
self.level = 0
|
||||
# this feature name conflicts with the attribute with Logger
|
||||
# rename it to avoid some corner cases that result in comparing `str` and `int`
|
||||
self.__level = 0
|
||||
|
||||
@property
|
||||
def logger(self):
|
||||
logger = logging.getLogger(self.module_name)
|
||||
logger.setLevel(self.level)
|
||||
logger.setLevel(self.__level)
|
||||
return logger
|
||||
|
||||
def setLevel(self, level):
|
||||
self.level = level
|
||||
self.__level = level
|
||||
|
||||
def __getattr__(self, name):
|
||||
# During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error.
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelFT(Model):
|
||||
|
||||
# Finetune model based on previous trained model
|
||||
with R.start(experiment_name="finetune model"):
|
||||
recorder = R.get_recorder(rid, experiment_name="init models")
|
||||
recorder = R.get_recorder(recorder_id=rid, experiment_name="init models")
|
||||
model = recorder.load_object("init_model")
|
||||
model.finetune(dataset, num_boost_round=10)
|
||||
|
||||
|
||||
@@ -8,13 +8,15 @@ There are two steps in each Trainer including ``train``(make model recorder) and
|
||||
This is a concept called ``DelayTrainer``, which can be used in online simulating for parallel training.
|
||||
In ``DelayTrainer``, the first step is only to save some necessary info to model recorders, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting.
|
||||
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
``Qlib`` offer two kinds of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically.
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
@@ -151,6 +153,9 @@ class Trainer:
|
||||
"""
|
||||
return self.delay
|
||||
|
||||
def __call__(self, *args, **kwargs) -> list:
|
||||
return self.end_train(self.train(*args, **kwargs))
|
||||
|
||||
|
||||
class TrainerR(Trainer):
|
||||
"""
|
||||
@@ -190,6 +195,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -213,6 +220,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -250,6 +259,8 @@ class DelayTrainerR(TrainerR):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -275,7 +286,12 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(
|
||||
self, experiment_name: str = None, task_pool: str = None, train_func=task_train, skip_run_task: bool = False
|
||||
):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
@@ -283,11 +299,16 @@ class TrainerRM(Trainer):
|
||||
experiment_name (str): the default name of experiment.
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.task_pool = task_pool
|
||||
self.train_func = train_func
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(
|
||||
self,
|
||||
@@ -315,6 +336,8 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -326,19 +349,26 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not self.is_delay():
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -352,10 +382,33 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(
|
||||
self,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
):
|
||||
"""
|
||||
The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -369,6 +422,7 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool: str = None,
|
||||
train_func=begin_task_train,
|
||||
end_train_func=end_task_train,
|
||||
skip_run_task: bool = False,
|
||||
):
|
||||
"""
|
||||
Init DelayTrainerRM.
|
||||
@@ -378,10 +432,15 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool (str): task pool name in TaskManager. None for use same name as experiment_name.
|
||||
train_func (Callable, optional): default train method. Defaults to `begin_task_train`.
|
||||
end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`.
|
||||
skip_run_task (bool):
|
||||
If skip_run_task == True:
|
||||
Only run_task in the worker. Otherwise skip run_task.
|
||||
E.g. Starting trainer on a CPU VM and then waiting tasks to be finished on GPU VMs.
|
||||
"""
|
||||
super().__init__(experiment_name, task_pool, train_func)
|
||||
self.end_train_func = end_train_func
|
||||
self.delay = True
|
||||
self.skip_run_task = skip_run_task
|
||||
|
||||
def train(self, tasks: list, train_func=None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -395,6 +454,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
@@ -410,8 +471,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
@@ -421,7 +480,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -429,18 +489,45 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
if not self.skip_run_task:
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(self, end_train_func=None, experiment_name: str = None):
|
||||
"""
|
||||
The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
|
||||
|
||||
Args:
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool=task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
"instruments": instruments,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
def get_dataset_config(
|
||||
dataset_class=DATASET_ALPHA158_CLASS,
|
||||
train=("2008-01-01", "2014-12-31"),
|
||||
valid=("2015-01-01", "2016-12-31"),
|
||||
test=("2017-01-01", "2020-08-01"),
|
||||
handler_kwargs={"instruments": CSI300_MARKET},
|
||||
):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
"kwargs": get_data_handler_config(**handler_kwargs),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"train": train,
|
||||
"valid": valid,
|
||||
"test": test,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
|
||||
# use for rolling_online_managment.py
|
||||
ROLLING_HANDLER_CONFIG = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ROLLING_DATASET_CONFIG = {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
# use for online_management_simulate.py
|
||||
ONLINE_HANDLER_CONFIG = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ONLINE_DATASET_CONFIG = {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
@@ -642,6 +642,28 @@ def split_pred(pred, number=None, split_date=None):
|
||||
return pred_left, pred_right
|
||||
|
||||
|
||||
def time_to_slc_point(t: Union[None, str, pd.Timestamp]) -> Union[None, pd.Timestamp]:
|
||||
"""
|
||||
Time slicing in Qlib or Pandas is a frequently-used action.
|
||||
However, user often input all kinds of data format to represent time.
|
||||
This function will help user to convert these inputs into a uniform format which is friendly to time slicing.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : Union[None, str, pd.Timestamp]
|
||||
original time
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[None, pd.Timestamp]:
|
||||
"""
|
||||
if t is None:
|
||||
# None represents unbounded in Qlib or Pandas(e.g. df.loc[slice(None, "20210303")]).
|
||||
return t
|
||||
else:
|
||||
return pd.Timestamp(t)
|
||||
|
||||
|
||||
def can_use_cache():
|
||||
res = True
|
||||
r = get_redis_connection()
|
||||
|
||||
12
qlib/utils/exceptions.py
Normal file
12
qlib/utils/exceptions.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
# Base exception class
|
||||
class QlibException(Exception):
|
||||
def __init__(self, message):
|
||||
super(QlibException, self).__init__(message)
|
||||
|
||||
|
||||
# Error type for reinitialization when starting an experiment
|
||||
class RecorderInitializationError(QlibException):
|
||||
pass
|
||||
@@ -7,6 +7,7 @@ from .expm import MLflowExpManager
|
||||
from .exp import Experiment
|
||||
from .recorder import Recorder
|
||||
from ..utils import Wrapper
|
||||
from ..utils.exceptions import RecorderInitializationError
|
||||
|
||||
|
||||
class QlibRecorder:
|
||||
@@ -215,9 +216,9 @@ class QlibRecorder:
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
"""
|
||||
return self.get_exp(experiment_id, experiment_name).list_recorders()
|
||||
return self.get_exp(experiment_id=experiment_id, experiment_name=experiment_name).list_recorders()
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True) -> Experiment:
|
||||
"""
|
||||
Method for retrieving an experiment with given id or name. Once the `create` argument is set to
|
||||
True, if no valid experiment is found, this method will create one for you. Otherwise, it will
|
||||
@@ -262,7 +263,7 @@ class QlibRecorder:
|
||||
|
||||
# Case 2
|
||||
with R.start('test'):
|
||||
exp = R.get_exp('test1')
|
||||
exp = R.get_exp(experiment_name='test1')
|
||||
|
||||
# Case 3
|
||||
exp = R.get_exp() -> a default experiment.
|
||||
@@ -287,7 +288,9 @@ class QlibRecorder:
|
||||
-------
|
||||
An experiment instance with given id or name.
|
||||
"""
|
||||
return self.exp_manager.get_exp(experiment_id, experiment_name, create, start=False)
|
||||
return self.exp_manager.get_exp(
|
||||
experiment_id=experiment_id, experiment_name=experiment_name, create=create, start=False
|
||||
)
|
||||
|
||||
def delete_exp(self, experiment_id=None, experiment_name=None):
|
||||
"""
|
||||
@@ -331,7 +334,9 @@ class QlibRecorder:
|
||||
"""
|
||||
self.exp_manager.set_uri(uri)
|
||||
|
||||
def get_recorder(self, recorder_id=None, recorder_name=None, experiment_name=None) -> Recorder:
|
||||
def get_recorder(
|
||||
self, *, recorder_id=None, recorder_name=None, experiment_id=None, experiment_name=None
|
||||
) -> Recorder:
|
||||
"""
|
||||
Method for retrieving a recorder.
|
||||
|
||||
@@ -384,7 +389,7 @@ class QlibRecorder:
|
||||
-------
|
||||
A recorder instance.
|
||||
"""
|
||||
return self.get_exp(experiment_name=experiment_name, create=False).get_recorder(
|
||||
return self.get_exp(experiment_name=experiment_name, experiment_id=experiment_id, create=False).get_recorder(
|
||||
recorder_id, recorder_name, create=False, start=False
|
||||
)
|
||||
|
||||
@@ -525,14 +530,29 @@ class QlibRecorder:
|
||||
self.get_exp().get_recorder().set_tags(**kwargs)
|
||||
|
||||
|
||||
class RecorderWrapper(Wrapper):
|
||||
"""
|
||||
Wrapper class for QlibRecorder, which detects whether users reinitialize qlib when already starting an experiment.
|
||||
"""
|
||||
|
||||
def register(self, provider):
|
||||
if self._provider is not None:
|
||||
expm = getattr(self._provider, "exp_manager")
|
||||
if expm.active_experiment is not None:
|
||||
raise RecorderInitializationError(
|
||||
"Please don't reinitialize Qlib if QlibRecorder is already acivated. Otherwise, the experiment stored location will be modified."
|
||||
)
|
||||
self._provider = provider
|
||||
|
||||
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 9):
|
||||
from typing import Annotated
|
||||
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, Wrapper]
|
||||
QlibRecorderWrapper = Annotated[QlibRecorder, RecorderWrapper]
|
||||
else:
|
||||
QlibRecorderWrapper = QlibRecorder
|
||||
|
||||
# global record
|
||||
R: QlibRecorderWrapper = Wrapper()
|
||||
R: QlibRecorderWrapper = RecorderWrapper()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from typing import Union
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
@@ -213,11 +214,15 @@ class Experiment:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `_get_recorder` method")
|
||||
|
||||
def list_recorders(self):
|
||||
def list_recorders(self, **flt_kwargs):
|
||||
"""
|
||||
List all the existing recorders of this experiment. Please first get the experiment instance before calling this method.
|
||||
If user want to use the method `R.list_recorders()`, please refer to the related API document in `QlibRecorder`.
|
||||
|
||||
flt_kwargs : dict
|
||||
filter recorders by conditions
|
||||
e.g. list_recorders(status=Recorder.STATUS_FI)
|
||||
|
||||
Returns
|
||||
-------
|
||||
A dictionary (id -> recorder) of recorder information that being stored.
|
||||
@@ -320,11 +325,21 @@ class MLflowExperiment(Experiment):
|
||||
|
||||
UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!!
|
||||
|
||||
def list_recorders(self, max_results=UNLIMITED):
|
||||
def list_recorders(self, max_results: int = UNLIMITED, status: Union[str, None] = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
max_results : int
|
||||
the number limitation of the results
|
||||
status : str
|
||||
the criteria based on status to filter results.
|
||||
`None` indicates no filtering.
|
||||
"""
|
||||
runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results)
|
||||
recorders = dict()
|
||||
for i in range(len(runs)):
|
||||
recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i])
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
if status is None or recorder.status == status:
|
||||
recorders[runs[i].info.run_id] = recorder
|
||||
|
||||
return recorders
|
||||
|
||||
@@ -109,7 +109,7 @@ class ExpManager:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `search_records` method.")
|
||||
|
||||
def get_exp(self, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
def get_exp(self, *, experiment_id=None, experiment_name=None, create: bool = True, start: bool = False):
|
||||
"""
|
||||
Retrieve an experiment. This method includes getting an active experiment, and get_or_create a specific experiment.
|
||||
|
||||
@@ -190,7 +190,7 @@ class ExpManager:
|
||||
except ValueError:
|
||||
if experiment_name is None:
|
||||
experiment_name = self._default_exp_name
|
||||
logger.info(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
logger.warning(f"No valid experiment found. Create a new experiment with name {experiment_name}.")
|
||||
return self.create_exp(experiment_name), True
|
||||
|
||||
def _get_exp(self, experiment_id=None, experiment_name=None) -> Experiment:
|
||||
@@ -352,6 +352,8 @@ class MLflowExpManager(ExpManager):
|
||||
), "Please input at least one of experiment/recorder id or name before retrieving experiment/recorder."
|
||||
if experiment_id is not None:
|
||||
try:
|
||||
# NOTE: the mlflow's experiment_id must be str type...
|
||||
# https://www.mlflow.org/docs/latest/python_api/mlflow.tracking.html#mlflow.tracking.MlflowClient.get_experiment
|
||||
exp = self.client.get_experiment(experiment_id)
|
||||
if exp.lifecycle_stage.upper() == "DELETED":
|
||||
raise MlflowException("No valid experiment has been found.")
|
||||
|
||||
@@ -6,7 +6,7 @@ OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run
|
||||
|
||||
With the change of time, the decisive models will be also changed. In this module, we call those contributing models `online` models.
|
||||
In every routine(such as every day or every minute), the `online` models may be changed and the prediction of them needs to be updated.
|
||||
So this module provides a series of methods to control this process.
|
||||
So this module provides a series of methods to control this process.
|
||||
|
||||
This module also provides a method to simulate `Online Strategy <#Online Strategy>`_ in history.
|
||||
Which means you can verify your strategy or find a better one.
|
||||
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It makes user can train all tasks in
|
||||
the end of `routine` or `first_train`.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
@@ -29,7 +31,7 @@ Simulation + Trainer When your models have some temporal dependence on the
|
||||
|
||||
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
|
||||
for the ability to multitasking. It means all tasks in all routines
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
can be REAL trained at the end of simulating. The signals will be prepared well at
|
||||
different time segments (based on whether or not any new model is online).
|
||||
========================= ===================================================================================
|
||||
"""
|
||||
@@ -103,17 +105,23 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
if strategies is None:
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
|
||||
models_list = []
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
|
||||
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
|
||||
# start.
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
@@ -139,33 +147,41 @@ class OnlineManager(Serializable):
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
|
||||
models_list = []
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.trainer.is_delay():
|
||||
# The online model may changes in the above processes
|
||||
# So updating the predictions of online models should be the last step
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
def get_collector(self, **kwargs) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can be a basis as the signals preparation.
|
||||
|
||||
Args:
|
||||
**kwargs: the params for get_collector.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
@@ -225,7 +241,7 @@ class OnlineManager(Serializable):
|
||||
SIM_LOG_NAME = "SIMULATE_INFO"
|
||||
|
||||
def simulate(
|
||||
self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
self, end_time=None, frequency="day", task_kwargs={}, model_kwargs={}, signal_kwargs={}
|
||||
) -> Union[pd.Series, pd.DataFrame]:
|
||||
"""
|
||||
Starting from the current time, this method will simulate every routine in OnlineManager until the end time.
|
||||
@@ -297,6 +313,7 @@ class OnlineManager(Serializable):
|
||||
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
# FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
|
||||
self.logger.warn(
|
||||
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
|
||||
@@ -52,6 +52,12 @@ class OnlineStrategy:
|
||||
|
||||
NOTE: Reset all online models to trained models. If there are no trained models, then do nothing.
|
||||
|
||||
**NOTE**:
|
||||
Current implementation is very naive. Here is a more complex situation which is more closer to the
|
||||
practical scenarios.
|
||||
1. Train new models at the day before `test_start` (at time stamp `T`)
|
||||
2. Switch models at the `test_start` (at time timestamp `T + 1` typically)
|
||||
|
||||
Args:
|
||||
models (list): a list of models.
|
||||
cur_time (pd.Dataframe): current time from OnlineManger. None for the latest.
|
||||
|
||||
@@ -136,7 +136,7 @@ class PredUpdater(RecordUpdater):
|
||||
# https://github.com/pytorch/pytorch/issues/16797
|
||||
|
||||
start_time = get_date_by_shift(self.last_end, 1, freq=self.freq)
|
||||
if start_time >= self.to_date:
|
||||
if start_time > self.to_date:
|
||||
self.logger.info(
|
||||
f"The prediction in {self.record.info['id']} are latest ({start_time}). No need to update to {self.to_date}."
|
||||
)
|
||||
|
||||
@@ -8,8 +8,10 @@ This allows us to use efficient submodels as the market-style changing.
|
||||
"""
|
||||
|
||||
from typing import List, Union
|
||||
from qlib.data.dataset import TSDatasetH
|
||||
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils import get_cls_kwargs
|
||||
from qlib.workflow.online.update import PredUpdater
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.utils import list_recorders
|
||||
@@ -161,8 +163,9 @@ class OnlineToolR(OnlineTool):
|
||||
hist_ref = 0
|
||||
task = rec.load_object("task")
|
||||
# Special treatment of historical dependencies
|
||||
if task["dataset"]["class"] == "TSDatasetH":
|
||||
hist_ref = task["dataset"]["kwargs"]["step_len"]
|
||||
cls, kwargs = get_cls_kwargs(task["dataset"], default_module="qlib.data.dataset")
|
||||
if issubclass(cls, TSDatasetH):
|
||||
hist_ref = kwargs.get("step_len", TSDatasetH.DEFAULT_STEP_LEN)
|
||||
PredUpdater(rec, to_date=to_date, hist_ref=hist_ref).update()
|
||||
|
||||
self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.")
|
||||
|
||||
@@ -6,6 +6,7 @@ Collector module can collect objects from everywhere and process them such as me
|
||||
"""
|
||||
|
||||
from typing import Callable, Dict, List
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.serial import Serializable
|
||||
from qlib.workflow import R
|
||||
|
||||
@@ -192,6 +193,7 @@ class RecorderCollector(Collector):
|
||||
if rec_filter_func is None or rec_filter_func(rec):
|
||||
recs_flt[rid] = rec
|
||||
|
||||
logger = get_module_logger("RecorderCollector")
|
||||
for _, rec in recs_flt.items():
|
||||
rec_key = self.rec_key_func(rec)
|
||||
for key in artifacts_key:
|
||||
@@ -205,7 +207,13 @@ class RecorderCollector(Collector):
|
||||
# only collect existing artifact
|
||||
continue
|
||||
raise e
|
||||
collect_dict.setdefault(key, {})[rec_key] = artifact
|
||||
# give user some warning if the values are overridden
|
||||
cdd = collect_dict.setdefault(key, {})
|
||||
if rec_key in cdd:
|
||||
logger.warning(
|
||||
f"key '{rec_key}' is duplicated. Previous value will be overrides. Please check you `rec_key_func`"
|
||||
)
|
||||
cdd[rec_key] = artifact
|
||||
|
||||
return collect_dict
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ TaskGenerator module can generate many tasks based on TaskGen and some task temp
|
||||
import abc
|
||||
import copy
|
||||
from typing import List, Union, Callable
|
||||
|
||||
from qlib.utils import transform_end_date
|
||||
from .utils import TimeAdjuster
|
||||
|
||||
|
||||
@@ -199,7 +201,7 @@ class RollingGen(TaskGen):
|
||||
# First rolling
|
||||
# 1) prepare the end point
|
||||
segments: dict = copy.deepcopy(self.ta.align_seg(t["dataset"]["kwargs"]["segments"]))
|
||||
test_end = self.ta.max() if segments[self.test_key][1] is None else segments[self.test_key][1]
|
||||
test_end = transform_end_date(segments[self.test_key][1])
|
||||
# 2) and init test segments
|
||||
test_start_idx = self.ta.align_idx(segments[self.test_key][0])
|
||||
segments[self.test_key] = (self.ta.get(test_start_idx), self.ta.get(test_start_idx + self.step - 1))
|
||||
|
||||
@@ -69,28 +69,29 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
def __init__(self, task_pool: str):
|
||||
"""
|
||||
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
|
||||
A TaskManager instance serves a specific task pool.
|
||||
The static method of this module serves the whole MongoDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.task_pool = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self) -> list:
|
||||
@staticmethod
|
||||
def list() -> list:
|
||||
"""
|
||||
List the all collection(task_pool) of the db
|
||||
List the all collection(task_pool) of the db.
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
return get_mongodb().list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
@@ -109,6 +110,25 @@ class TaskManager:
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def _decode_query(self, query):
|
||||
"""
|
||||
If the query includes any `_id`, then it needs `ObjectId` to decode.
|
||||
For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
|
||||
|
||||
Args:
|
||||
query (dict): query dict. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
dict: the query after decoding.
|
||||
"""
|
||||
if "_id" in query:
|
||||
if isinstance(query["_id"], dict):
|
||||
for key in query["_id"]:
|
||||
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
|
||||
else:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
return query
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace a old one
|
||||
@@ -224,8 +244,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
@@ -253,10 +272,10 @@ class TaskManager:
|
||||
task = self.fetch_task(query=query, status=status)
|
||||
try:
|
||||
yield task
|
||||
except Exception:
|
||||
except (Exception, KeyboardInterrupt): # KeyboardInterrupt is not a subclass of Exception
|
||||
if task is not None:
|
||||
self.logger.info("Returning task before raising error")
|
||||
self.return_task(task)
|
||||
self.return_task(task, status=status) # return task as the original status
|
||||
self.logger.info("Task returned")
|
||||
raise
|
||||
|
||||
@@ -283,12 +302,11 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
def re_query(self, _id) -> dict:
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
@@ -339,8 +357,7 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
@@ -354,8 +371,7 @@ class TaskManager:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
@@ -377,8 +393,7 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
@@ -396,15 +411,29 @@ class TaskManager:
|
||||
self.task_pool.update_one({"_id": task["_id"]}, update_dict)
|
||||
|
||||
def _get_undone_n(self, task_stat):
|
||||
return task_stat.get(self.STATUS_WAITING, 0) + task_stat.get(self.STATUS_RUNNING, 0)
|
||||
return (
|
||||
task_stat.get(self.STATUS_WAITING, 0)
|
||||
+ task_stat.get(self.STATUS_RUNNING, 0)
|
||||
+ task_stat.get(self.STATUS_PART_DONE, 0)
|
||||
)
|
||||
|
||||
def _get_total(self, task_stat):
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
"""
|
||||
When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
|
||||
So main progress should wait until all tasks are trained well by other progress or machines.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warning(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
@@ -17,7 +17,6 @@ def experiment_exit_handler():
|
||||
Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is
|
||||
ended, another call of `R.end_exp` will not take effect.
|
||||
"""
|
||||
signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
|
||||
sys.excepthook = experiment_exception_hook # handle uncaught exception
|
||||
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
|
||||
|
||||
@@ -39,10 +38,3 @@ def experiment_exception_hook(type, value, tb):
|
||||
print(f"{type.__name__}: {value}")
|
||||
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
|
||||
def experiment_kill_signal_handler(signum, frame):
|
||||
"""
|
||||
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
|
||||
"""
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
@@ -5,6 +5,5 @@ numpy
|
||||
pandas
|
||||
tqdm
|
||||
lxml
|
||||
loguru
|
||||
yahooquery
|
||||
joblib
|
||||
|
||||
Reference in New Issue
Block a user