mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-06 04:20:57 +08:00
Merge remote-tracking branch 'qlib/main' into qlib_main
# Conflicts: # scripts/data_collector/yahoo/README.md
This commit is contained in:
52
examples/benchmarks/TCTS/README.md
Normal file
52
examples/benchmarks/TCTS/README.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
52
examples/benchmarks/TCTS/TCTS.md
Normal file
52
examples/benchmarks/TCTS/TCTS.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
|
||||
|
||||
### DataSet
|
||||
* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main tasks <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure1) refers to forecasting return of stock <img src="https://render.githubusercontent.com/render/math?math=i"> as following,
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\price_i^{t+k}}{\price_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">, in this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_10"> are used.
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
|
||||
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -1) - 1",
|
||||
"Ref($close, -4) / Ref($close, -1) - 1",
|
||||
"Ref($close, -5) / Ref($close, -1) - 1",
|
||||
"Ref($close, -6) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 5
|
||||
fore_lr: 5e-7
|
||||
weight_lr: 5e-7
|
||||
steps: 3
|
||||
target_label: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,7 +14,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example is about how can simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
# FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# train tasks by other progress or machines for multiprocessing
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
Reference in New Issue
Block a user