diff --git a/examples/benchmarks/TCTS/TCTS.md b/examples/benchmarks/TCTS/TCTS.md
new file mode 100644
index 000000000..ee67ffbeb
--- /dev/null
+++ b/examples/benchmarks/TCTS/TCTS.md
@@ -0,0 +1,52 @@
+# Temporally Correlated Task Scheduling for Sequence Learning
+We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
+
+### Background
+Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is usually associated with multiple temporally correlated auxiliary tasks, which are different in terms of how much input information to use or which future step to predict. In stock trend forecasting, as demonstrated in Figure1, one can predict the price of a stock in different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework to make use of those temporally correlated tasks to help each other.
+
+
+
+
+
+
+### Method
+Given that there are usually multiple temporally correlated tasks, the key challenge lies in which tasks to use and when to use them in the training process. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., in current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is demonstrated in Figure2.
+
+
+
+
+
+At step
, with training data
, the scheduler
chooses a suitable task
(green solid lines) to update the model
(blue solid lines). After
steps, we evaluate the model
on the validation set and update the scheduler
(green dashed lines).
+
+### DataSet
+* We use the historical transaction data for 300 stocks on [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) from 01/01/2008 to 08/01/2020.
+* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
+
+### Experiments
+#### Task Description
+* The main tasks
(
in Figure1) refers to forecasting return of stock
as following,
+
+

+
+
+* Temporally correlated task sets
, in this paper,
,
and
are used.
+#### Baselines
+* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
+* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
+* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task
, then
, and gradually move to the last one.
+#### Result
+| Methods |
|
|
|
+| :----: | :----: | :----: | :----: |
+| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
+| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
+| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
+| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
+| MTL(
) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
+| CL(
) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
+| Ours(
) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
+| MTL(
) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
+| CL(
) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
+| Ours(
) | 0.075 / 1.849 | 0.032 /1.939 | 0.021 / 1.955 |
+| MTL(
) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
+| CL(
) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
+| Ours(
) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942|
\ No newline at end of file
diff --git a/examples/benchmarks/TCTS/task_description.png b/examples/benchmarks/TCTS/task_description.png
new file mode 100644
index 000000000..7a9005bf2
Binary files /dev/null and b/examples/benchmarks/TCTS/task_description.png differ
diff --git a/examples/benchmarks/TCTS/workflow.png b/examples/benchmarks/TCTS/workflow.png
new file mode 100644
index 000000000..403a17de3
Binary files /dev/null and b/examples/benchmarks/TCTS/workflow.png differ
diff --git a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
new file mode 100644
index 000000000..589f4b43e
--- /dev/null
+++ b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
@@ -0,0 +1,93 @@
+qlib_init:
+ provider_uri: "~/.qlib/qlib_data/cn_data"
+ region: cn
+market: &market csi300
+benchmark: &benchmark SH000300
+data_handler_config: &data_handler_config
+ start_time: 2008-01-01
+ end_time: 2020-08-01
+ fit_start_time: 2008-01-01
+ fit_end_time: 2014-12-31
+ instruments: *market
+ infer_processors:
+ - class: RobustZScoreNorm
+ kwargs:
+ fields_group: feature
+ clip_outlier: true
+ - class: Fillna
+ kwargs:
+ fields_group: feature
+ learn_processors:
+ - class: DropnaLabel
+ - class: CSRankNorm
+ kwargs:
+ fields_group: label
+ label: ["Ref($close, -2) / Ref($close, -1) - 1",
+ "Ref($close, -3) / Ref($close, -1) - 1",
+ "Ref($close, -4) / Ref($close, -1) - 1",
+ "Ref($close, -5) / Ref($close, -1) - 1",
+ "Ref($close, -6) / Ref($close, -1) - 1"]
+port_analysis_config: &port_analysis_config
+ strategy:
+ class: TopkDropoutStrategy
+ module_path: qlib.contrib.strategy.strategy
+ kwargs:
+ topk: 50
+ n_drop: 5
+ backtest:
+ verbose: False
+ limit_threshold: 0.095
+ account: 100000000
+ benchmark: *benchmark
+ deal_price: close
+ open_cost: 0.0005
+ close_cost: 0.0015
+ min_cost: 5
+task:
+ model:
+ class: TCTS
+ module_path: qlib.contrib.model.pytorch_tcts
+ kwargs:
+ d_feat: 6
+ hidden_size: 64
+ num_layers: 2
+ dropout: 0.0
+ n_epochs: 200
+ lr: 1e-3
+ early_stop: 20
+ batch_size: 800
+ metric: loss
+ loss: mse
+ GPU: 0
+ fore_optimizer: adam
+ weight_optimizer: adam
+ output_dim: 5
+ fore_lr: 5e-7
+ weight_lr: 5e-7
+ steps: 3
+ target_label: 0
+ dataset:
+ class: DatasetH
+ module_path: qlib.data.dataset
+ kwargs:
+ handler:
+ class: Alpha360
+ module_path: qlib.contrib.data.handler
+ kwargs: *data_handler_config
+ segments:
+ train: [2008-01-01, 2014-12-31]
+ valid: [2015-01-01, 2016-12-31]
+ test: [2017-01-01, 2020-08-01]
+ record:
+ - class: SignalRecord
+ module_path: qlib.workflow.record_temp
+ kwargs: {}
+ - class: SigAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ ana_long_short: False
+ ann_scaler: 252
+ - class: PortAnaRecord
+ module_path: qlib.workflow.record_temp
+ kwargs:
+ config: *port_analysis_config
\ No newline at end of file
diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py
index 9ef8694bf..844f18198 100644
--- a/examples/model_rolling/task_manager_rolling.py
+++ b/examples/model_rolling/task_manager_rolling.py
@@ -4,6 +4,7 @@
"""
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
After training, how to collect the rolling results will be shown in task_collecting.
+Based on the ability of TaskManager, `worker` method offer a simple way for multiprocessing.
"""
from pprint import pprint
@@ -13,7 +14,7 @@ import qlib
from qlib.config import REG_CN
from qlib.workflow import R
from qlib.workflow.task.gen import RollingGen, task_generator
-from qlib.workflow.task.manage import TaskManager
+from qlib.workflow.task.manage import TaskManager, run_task
from qlib.workflow.task.collect import RecorderCollector
from qlib.model.ens.group import RollingGroup
from qlib.model.trainer import TrainerRM
@@ -68,6 +69,11 @@ class RollingTaskExample:
trainer = TrainerRM(self.experiment_name, self.task_pool)
trainer.train(tasks)
+ def worker(self):
+ # train tasks by other progress or machines for multiprocessing. It is same as TrainerRM.worker.
+ print("========== worker ==========")
+ run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
+
def task_collecting(self):
print("========== task_collecting ==========")
diff --git a/examples/nested_decision_execution/README.md b/examples/nested_decision_execution/README.md
index 312f94d31..382e5a320 100644
--- a/examples/nested_decision_execution/README.md
+++ b/examples/nested_decision_execution/README.md
@@ -1,6 +1,6 @@
# Nested Decision Execution
-This worflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.
+This workflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.
## Weekly Portfolio Generation and Daily Order Execution
diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py
index b8e9e5fb5..2286f4f12 100644
--- a/examples/nested_decision_execution/workflow.py
+++ b/examples/nested_decision_execution/workflow.py
@@ -19,10 +19,10 @@ class NestedDecisonExecutionWorkflow:
benchmark = "SH000300"
data_handler_config = {
- "start_time": "2008-01-01",
- "end_time": "2021-01-20",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
+ "start_time": "2010-01-01",
+ "end_time": "2021-05-28",
+ "fit_start_time": "2010-01-01",
+ "fit_end_time": "2017-12-31",
"instruments": market,
}
@@ -52,9 +52,9 @@ class NestedDecisonExecutionWorkflow:
"kwargs": data_handler_config,
},
"segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2021-01-20"),
+ "train": ("2010-01-01", "2017-12-31"),
+ "valid": ("2018-01-01", "2019-12-31"),
+ "test": ("2020-01-01", "2021-05-28"),
},
},
},
@@ -67,33 +67,45 @@ class NestedDecisonExecutionWorkflow:
"kwargs": {
"time_per_step": "week",
"inner_executor": {
- "class": "SimulatorExecutor",
+ "class": "NestedExecutor",
"module_path": "qlib.backtest.executor",
"kwargs": {
"time_per_step": "day",
- "verbose": True,
- "generate_report": True,
+ "inner_executor": {
+ "class": "SimulatorExecutor",
+ "module_path": "qlib.backtest.executor",
+ "kwargs": {
+ "time_per_step": "15min",
+ "generate_report": True,
+ "verbose": True,
+ },
+ },
+ "inner_strategy": {
+ "class": "TWAPStrategy",
+ "module_path": "qlib.contrib.strategy.rule_strategy",
+ },
+ "show_indicator": True,
},
},
"inner_strategy": {
- "class": "SBBStrategyEMA",
+ "class": "VAStrategy",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"freq": "day",
"instruments": market,
},
},
- "generate_report": True,
"track_data": True,
+ "show_indicator": True,
},
},
"backtest": {
- "start_time": "2017-01-01",
- "end_time": "2020-08-01",
+ "start_time": "2020-09-20",
+ "end_time": "2021-05-28",
"account": 100000000,
"benchmark": benchmark,
"exchange_kwargs": {
- "freq": "day",
+ "freq": "1min",
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
@@ -105,11 +117,40 @@ class NestedDecisonExecutionWorkflow:
def _init_qlib(self):
"""initialize qlib"""
- provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri):
- print(f"Qlib data is not found in {provider_uri}")
- GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
- qlib.init(provider_uri=provider_uri, region=REG_CN)
+ provider_uri_day = "/data1/v-xiabi/qlib/qlib_data/cn_data" # target_dir
+ GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
+ # provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
+ provider_uri_1min = "/data1/v-xiabi/qlib/qlib_data/cn_data_highfreq"
+ GetData().qlib_data(
+ target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
+ )
+
+ provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
+ client_config = {
+ "calendar_provider": {
+ "class": "LocalCalendarProvider",
+ "module_path": "qlib.data.data",
+ "kwargs": {
+ "backend": {
+ "class": "FileCalendarStorage",
+ "module_path": "qlib.data.storage.file_storage",
+ "kwargs": {"provider_uri_map": provider_uri_map},
+ }
+ },
+ },
+ "feature_provider": {
+ "class": "LocalFeatureProvider",
+ "module_path": "qlib.data.data",
+ "kwargs": {
+ "backend": {
+ "class": "FileFeatureStorage",
+ "module_path": "qlib.data.storage.file_storage",
+ "kwargs": {"provider_uri_map": provider_uri_map},
+ }
+ },
+ },
+ }
+ qlib.init(provider_uri=provider_uri_day, **client_config)
def _train_model(self, model, dataset):
with R.start(experiment_name="train"):
@@ -141,7 +182,7 @@ class NestedDecisonExecutionWorkflow:
with R.start(experiment_name="backtest"):
recorder = R.get_recorder()
- par = PortAnaRecord(recorder, self.port_analysis_config, "day")
+ par = PortAnaRecord(recorder, self.port_analysis_config, "15minute")
par.generate()
def collect_data(self):
@@ -165,98 +206,6 @@ class NestedDecisonExecutionWorkflow:
for trade_decision in data_generator:
print(trade_decision)
- def _init_qlib_with_backend(self):
- provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
- if not exists_qlib_data(provider_uri_1min):
- print(f"Qlib data is not found in {provider_uri_1min}")
- GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
-
- # TODO: update latest data
- provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
- if not exists_qlib_data(provider_uri_day):
- print(f"Qlib data is not found in {provider_uri_day}")
- GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN)
-
- provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
- client_config = {
- "calendar_provider": {
- "class": "LocalCalendarProvider",
- "module_path": "qlib.data.data",
- "kwargs": {
- "backend": {
- "class": "FileCalendarStorage",
- "module_path": "qlib.data.storage.file_storage",
- "kwargs": {"provider_uri_map": provider_uri_map},
- }
- },
- },
- "feature_provider": {
- "class": "LocalFeatureProvider",
- "module_path": "qlib.data.data",
- "kwargs": {
- "backend": {
- "class": "FileFeatureStorage",
- "module_path": "qlib.data.storage.file_storage",
- "kwargs": {"provider_uri_map": provider_uri_map},
- }
- },
- },
- }
- qlib.init(provider_uri=provider_uri_day, **client_config)
-
- def _get_highfreq_config(self, model, dataset):
-
- executor_config = self.port_analysis_config["executor"]
- # update executor with hierarchical decison freq ["day", "1min"]
- executor_config["kwargs"]["time_per_step"] = "day"
- executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "15min"
- backtest_config = self.port_analysis_config["backtest"]
-
- # yahoo highfreq data time
- backtest_config["start_time"] = "2020-09-20"
- backtest_config["end_time"] = "2021-01-20"
-
- # update benchmark, yahoo data don't have SH000300
- instruments = D.instruments(market="csi300")
- instrument_list = D.list_instruments(instruments=instruments, as_list=True)
- backtest_config["benchmark"] = instrument_list
-
- # update exchange config
- backtest_config["exchange_kwargs"]["freq"] = "1min"
-
- # set strategy
- strategy_config = {
- "class": "TopkDropoutStrategy",
- "module_path": "qlib.contrib.strategy.model_strategy",
- "kwargs": {
- "model": model,
- "dataset": dataset,
- "topk": 50,
- "n_drop": 5,
- },
- }
-
- return executor_config, strategy_config, backtest_config
-
- def backtest_highfreq(self):
- self._init_qlib_with_backend()
- model = init_instance_by_config(self.task["model"])
- dataset = init_instance_by_config(self.task["dataset"])
- self._train_model(model, dataset)
- executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset)
-
- highfreq_port_analysis_config = {
- "executor": executor_config,
- "strategy": strategy_config,
- "backtest": backtest_config,
- }
-
- with R.start(experiment_name="backtest_highfreq"):
-
- recorder = R.get_recorder()
- par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day")
- par.generate()
-
if __name__ == "__main__":
fire.Fire(NestedDecisonExecutionWorkflow)
diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py
index 8c9e77bf7..bd7c4675d 100644
--- a/examples/online_srv/online_management_simulate.py
+++ b/examples/online_srv/online_management_simulate.py
@@ -5,6 +5,7 @@
This example is about how can simulate the OnlineManager based on rolling tasks.
"""
+from pprint import pprint
import fire
import qlib
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.task.manage import TaskManager
-from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
+from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
class OnlineSimulationExample:
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
exp_name="rolling_exp",
- task_url="mongodb://10.0.0.4:27017/",
- task_db_name="rolling_db",
+ task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
+ task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
task_pool="rolling_task",
rolling_step=80,
start_time="2018-09-10",
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
"""
if tasks is None:
- tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
+ tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
self.exp_name = exp_name
self.task_pool = task_pool
self.start_time = start_time
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
self.rolling_gen = RollingGen(
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
- self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
+ self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
self.rolling_online_manager = OnlineManager(
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
trainer=self.trainer,
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
print("========== signals ==========")
print(self.rolling_online_manager.get_signals())
+ def worker(self):
+ # train tasks by other progress or machines for multiprocessing
+ # FIXME: only can call after finishing simulation when using DelayTrainerRM, or there will be some exception.
+ print("========== worker ==========")
+ if isinstance(self.trainer, TrainerRM):
+ self.trainer.worker()
+ else:
+ print(f"{type(self.trainer)} is not supported for worker.")
+
if __name__ == "__main__":
## to run all workflow automatically with your own parameters, use the command below
diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py
index 592f1f866..6abbbfb0e 100644
--- a/examples/online_srv/rolling_online_management.py
+++ b/examples/online_srv/rolling_online_management.py
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
import os
import fire
import qlib
+from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
from qlib.workflow import R
from qlib.workflow.online.strategy import RollingStrategy
from qlib.workflow.task.gen import RollingGen
from qlib.workflow.online.manager import OnlineManager
-from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
+from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
+from qlib.workflow.task.manage import TaskManager
class RollingOnlineExample:
@@ -25,16 +27,17 @@ class RollingOnlineExample:
self,
provider_uri="~/.qlib/qlib_data/cn_data",
region="cn",
- task_url="mongodb://10.0.0.4:27017/",
- task_db_name="rolling_db",
+ trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
+ task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
+ task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
rolling_step=550,
tasks=None,
add_tasks=None,
):
if add_tasks is None:
- add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
+ add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
if tasks is None:
- tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
+ tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
mongo_conf = {
"task_url": task_url, # your MongoDB url
"task_db_name": task_db_name, # database name
@@ -53,17 +56,28 @@ class RollingOnlineExample:
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
)
)
-
- self.rolling_online_manager = OnlineManager(strategies)
+ self.trainer = trainer
+ self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
_ROLLING_MANAGER_PATH = (
".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine.
)
+ def worker(self):
+ # train tasks by other progress or machines for multiprocessing
+ print("========== worker ==========")
+ if isinstance(self.trainer, TrainerRM):
+ for task in self.tasks + self.add_tasks:
+ name_id = task["model"]["class"]
+ self.trainer.worker(experiment_name=name_id)
+ else:
+ print(f"{type(self.trainer)} is not supported for worker.")
+
# Reset all things to the first status, be careful to save important data
def reset(self):
for task in self.tasks + self.add_tasks:
name_id = task["model"]["class"]
+ TaskManager(task_pool=name_id).remove()
exp = R.get_exp(experiment_name=name_id)
for rid in exp.list_recorders():
exp.delete_recorder(rid)
diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb
index 3d99bf1e1..e74897664 100644
--- a/examples/workflow_by_code.ipynb
+++ b/examples/workflow_by_code.ipynb
@@ -362,8 +362,9 @@
],
"metadata": {
"kernelspec": {
- "name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b",
- "display_name": "Python 3.8 ('qlib_backtest': conda)"
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -375,7 +376,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8"
+ "version": "3.8.3"
},
"toc": {
"base_numbering": 1,
@@ -389,11 +390,6 @@
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
- },
- "metadata": {
- "interpreter": {
- "hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b"
- }
}
},
"nbformat": 4,
diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py
index 1adad91d2..a3706008a 100644
--- a/qlib/backtest/__init__.py
+++ b/qlib/backtest/__init__.py
@@ -4,8 +4,8 @@
from .account import Account
from .exchange import Exchange
from .executor import BaseExecutor
-from .backtest import backtest as backtest_func
-from .backtest import collect_data as data_generator
+from .backtest import backtest_loop
+from .backtest import collect_data_loop
from .utils import CommonInfrastructure
from .order import Order
@@ -116,7 +116,7 @@ def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", acc
trade_strategy, trade_executor = get_strategy_executor(
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
)
- report_dict = backtest_func(start_time, end_time, trade_strategy, trade_executor)
+ report_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
return report_dict
@@ -126,6 +126,6 @@ def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300",
trade_strategy, trade_executor = get_strategy_executor(
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
)
- report_dict = yield from data_generator(start_time, end_time, trade_strategy, trade_executor)
+ report_dict = yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor)
return report_dict
diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py
index dfe248c68..71214036a 100644
--- a/qlib/backtest/account.py
+++ b/qlib/backtest/account.py
@@ -7,7 +7,7 @@ import warnings
import pandas as pd
from .position import Position
-from .report import Report
+from .report import Report, Indicator
from .order import Order
@@ -42,6 +42,7 @@ class Account:
def reset_report(self, freq, benchmark_config):
self.report = Report(freq, benchmark_config)
+ self.indicator = Indicator()
self.positions = {}
self.rtn = 0
self.ct = 0
diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py
index 1f0d2ac38..e9d864c92 100644
--- a/qlib/backtest/backtest.py
+++ b/qlib/backtest/backtest.py
@@ -2,8 +2,25 @@
# Licensed under the MIT License.
-def backtest(start_time, end_time, trade_strategy, trade_executor):
+def backtest_loop(start_time, end_time, trade_strategy, trade_executor):
+ """backtest funciton for the interaction of the outermost strategy and executor in the nested decison execution
+ Parameters
+ ----------
+ start_time : pd.Timestamp|str
+ closed start time for backtest
+ end_time : pd.Timestamp|str
+ closed end time for backtest
+ trade_strategy : BaseStrategy
+ the outermost portfolio strategy
+ trade_executor : BaseExecutor
+ the outermost executor
+
+ Returns
+ -------
+ report: Report
+ it records the trading report information
+ """
trade_executor.reset(start_time=start_time, end_time=end_time)
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
@@ -16,8 +33,14 @@ def backtest(start_time, end_time, trade_strategy, trade_executor):
return trade_executor.get_report()
-def collect_data(start_time, end_time, trade_strategy, trade_executor):
+def collect_data_loop(start_time, end_time, trade_strategy, trade_executor):
+ """Generator for collecting the trade decision data for rl training
+ Yields
+ -------
+ object
+ trade decision
+ """
trade_executor.reset(start_time=start_time, end_time=end_time)
level_infra = trade_executor.get_level_infra()
trade_strategy.reset(level_infra=level_infra)
@@ -26,5 +49,3 @@ def collect_data(start_time, end_time, trade_strategy, trade_executor):
while not trade_executor.finished():
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
_execute_result = yield from trade_executor.collect_data(_trade_decision)
-
- return trade_executor.get_report()
diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py
index 4fc01d8e2..6accb5e05 100644
--- a/qlib/backtest/exchange.py
+++ b/qlib/backtest/exchange.py
@@ -342,7 +342,10 @@ class Exchange:
return -deal_amount
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
- """Parameter:
+ """
+ Note: some future information is used in this function
+
+ Parameter:
target_position : dict { stock_id : amount }
current_postion : dict { stock_id : amount}
trade_unit : trade_unit
diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py
index 656073759..d68ff3ab1 100644
--- a/qlib/backtest/executor.py
+++ b/qlib/backtest/executor.py
@@ -3,14 +3,14 @@ import warnings
import pandas as pd
from typing import Union
-from ..utils import init_instance_by_config
-from ..utils.resam import parse_freq
-
-
from .order import Order
from .exchange import Exchange
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
+from ..utils import init_instance_by_config
+from ..utils.resam import parse_freq
+from ..strategy.base import BaseStrategy
+
class BaseExecutor:
"""Base executor for trading"""
@@ -20,6 +20,7 @@ class BaseExecutor:
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
+ show_indicator: bool = False,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
@@ -31,12 +32,14 @@ class BaseExecutor:
----------
time_per_step : str
trade time per trading step, used for genreate the trade calendar
+ show_indicator: bool, optional
+ whether to show indicators, such as FFR/PA/POS, .etc
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
whether to print trading info, by default False
track_data : bool, optional
- whether to generate trade_decision, will be used when making data for multi-level training
+ whether to generate trade_decision, will be used when training rl agent
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
- Else, `trade_decision` will not be generated
common_infra : CommonInfrastructure, optional:
@@ -48,6 +51,7 @@ class BaseExecutor:
"""
self.time_per_step = time_per_step
+ self.show_indicator = show_indicator
self.generate_report = generate_report
self.verbose = verbose
self.track_data = track_data
@@ -103,11 +107,27 @@ class BaseExecutor:
Returns
----------
execute_result : List[object]
- the executed result for trade decison
+ the executed result for trade decision
"""
raise NotImplementedError("execute is not implemented!")
def collect_data(self, trade_decision):
+ """Generator for collecting the trade decision data for rl training
+
+ Parameters
+ ----------
+ trade_decision : object
+
+ Returns
+ ----------
+ execute_result : List[object]
+ the executed result for trade decision
+
+ Yields
+ -------
+ object
+ trade decision
+ """
if self.track_data:
yield trade_decision
return self.execute(trade_decision)
@@ -122,6 +142,9 @@ class BaseExecutor:
"""Return all executors"""
return [self]
+ def get_trade_indicator(self):
+ return self.trade_account.indicator.trade_indicator
+
class NestedExecutor(BaseExecutor):
"""
@@ -129,8 +152,6 @@ class NestedExecutor(BaseExecutor):
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
"""
- from ..strategy.base import BaseStrategy
-
def __init__(
self,
time_per_step: str,
@@ -138,6 +159,7 @@ class NestedExecutor(BaseExecutor):
inner_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
+ show_indicator: bool = False,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
@@ -161,13 +183,14 @@ class NestedExecutor(BaseExecutor):
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
)
self.inner_strategy = init_instance_by_config(
- inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
+ inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
)
super(NestedExecutor, self).__init__(
time_per_step=time_per_step,
start_time=start_time,
end_time=end_time,
+ show_indicator=show_indicator,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
@@ -199,7 +222,7 @@ class NestedExecutor(BaseExecutor):
sub_level_infra = self.inner_executor.get_level_infra()
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
- def _update_trade_account(self):
+ def _update_trade_account(self, inner_indicators):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
self.trade_account.update_bar_count()
@@ -210,33 +233,44 @@ class NestedExecutor(BaseExecutor):
trade_exchange=self.trade_exchange,
)
+ self.trade_account.indicator.clear()
+ self.trade_account.indicator.agg_report_info(inner_indicators=inner_indicators)
+ self.trade_account.indicator.agg_FFR()
+ self.trade_account.indicator.agg_PA(inner_indicators=inner_indicators)
+
+ if self.show_indicator:
+ FFR_value = self.trade_account.indicator.get_statistics_FFR(method="value_weighted")
+ PA_value = self.trade_account.indicator.get_statistics_PA(method="value_weighted")
+ POS_values = self.trade_account.indicator.get_statistics_POS()
+ print(
+ "[Indicator({}) {:%Y-%m-%d}]: FFR: {}, PA: {}, POS: {}".format(
+ self.time_per_step, trade_start_time, FFR_value, PA_value, POS_values
+ )
+ )
+
def execute(self, trade_decision):
- self._init_sub_trading(trade_decision)
- execute_result = []
- _inner_execute_result = None
- while not self.inner_executor.finished():
- _inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
- _inner_execute_result = self.inner_executor.execute(trade_decision=_inner_trade_decision)
- execute_result.extend(_inner_execute_result)
- if hasattr(self, "trade_account"):
- self._update_trade_account()
- self.trade_calendar.step()
- return execute_result
+ for _data in self.collect_data(trade_decision):
+ pass
+ return self._execute_result
def collect_data(self, trade_decision):
if self.track_data:
yield trade_decision
- self.trade_calendar.step()
self._init_sub_trading(trade_decision)
execute_result = []
+ inner_indicators = []
_inner_execute_result = None
while not self.inner_executor.finished():
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
execute_result.extend(_inner_execute_result)
- if hasattr(self, "trade_account"):
- self._update_trade_account()
+ inner_indicators.append(self.inner_executor.get_trade_indicator())
+ if hasattr(self, "trade_account"):
+ self._update_trade_account(inner_indicators=inner_indicators)
+
+ self.trade_calendar.step()
+ self._execute_result = execute_result
return execute_result
def get_report(self):
@@ -261,6 +295,7 @@ class SimulatorExecutor(BaseExecutor):
time_per_step: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
+ show_indicator: bool = False,
generate_report: bool = False,
verbose: bool = False,
track_data: bool = False,
@@ -279,6 +314,7 @@ class SimulatorExecutor(BaseExecutor):
time_per_step=time_per_step,
start_time=start_time,
end_time=end_time,
+ show_indicator=show_indicator,
generate_report=generate_report,
verbose=verbose,
track_data=track_data,
@@ -337,7 +373,7 @@ class SimulatorExecutor(BaseExecutor):
else:
if self.verbose:
- print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
+ print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id))
# do nothing
pass
@@ -349,6 +385,25 @@ class SimulatorExecutor(BaseExecutor):
trade_end_time=trade_end_time,
trade_exchange=self.trade_exchange,
)
+
+ self.trade_account.indicator.clear()
+ self.trade_account.indicator.update_trade_info(trade_info=execute_result)
+ self.trade_account.indicator.update_FFR()
+ self.trade_account.indicator.update_PA(
+ freq=self.time_per_step, trade_start_time=trade_start_time, trade_end_time=trade_end_time
+ )
+ self.trade_account.indicator.record(trade_start_time=trade_start_time)
+
+ if self.show_indicator:
+ FFR_value = self.trade_account.indicator.get_statistics_FFR(method="value_weighted")
+ PA_value = self.trade_account.indicator.get_statistics_PA(method="value_weighted")
+ POS_values = self.trade_account.indicator.get_statistics_POS()
+ print(
+ "[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
+ self.time_per_step, trade_start_time, FFR_value, PA_value, POS_values
+ )
+ )
+
self.trade_calendar.step()
return execute_result
diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py
index 0668f81cf..d12595db5 100644
--- a/qlib/backtest/report.py
+++ b/qlib/backtest/report.py
@@ -7,10 +7,11 @@ from logging import warning
import pandas as pd
import pathlib
import warnings
+from pandas.core import groupby
from pandas.core.frame import DataFrame
-from ..utils.resam import parse_freq, resam_ts_data
+from ..utils.resam import parse_freq, resam_ts_data, get_higher_freq_feature
from ..data import D
from ..tests.config import CSI300_BENCH
@@ -79,19 +80,7 @@ class Report:
raise ValueError("benchmark freq can't be None!")
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
fields = ["$close/Ref($close,1)-1"]
- try:
- _temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
- except (ValueError, KeyError):
- _, norm_freq = parse_freq(freq)
- if norm_freq in ["month", "week", "day"]:
- try:
- _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
- except (ValueError, KeyError):
- _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
- elif norm_freq == "minute":
- _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
- else:
- raise ValueError(f"benchmark freq {freq} is not supported")
+ _temp_result, _ = get_higher_freq_feature(_codes, fields, start_time, end_time, freq=freq)
if len(_temp_result) == 0:
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
@@ -122,11 +111,11 @@ class Report:
turnover_rate=None,
cost_rate=None,
stock_value=None,
+ bench_value=None,
):
# check data
if None in [
trade_start_time,
- trade_end_time,
account_value,
cash,
return_rate,
@@ -135,8 +124,14 @@ class Report:
stock_value,
]:
raise ValueError(
- "None in [trade_start_time, trade_end_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
+ "None in [trade_start_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
)
+
+ if trade_end_time is None and bench_value is None:
+ raise ValueError("Both trade_end_time and bench_value is None, benchmark is not usable.")
+ elif bench_value is None:
+ bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
+
# update report data
self.accounts[trade_start_time] = account_value
self.returns[trade_start_time] = return_rate
@@ -144,7 +139,7 @@ class Report:
self.costs[trade_start_time] = cost_rate
self.values[trade_start_time] = stock_value
self.cashes[trade_start_time] = cash
- self.benches[trade_start_time] = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
+ self.benches[trade_start_time] = bench_value
# update latest_report_date
self.latest_report_time = trade_start_time
# finish daily report update
@@ -178,14 +173,162 @@ class Report:
index = r.index
self.init_vars()
- for trade_time in index:
+ for trade_start_time in index:
self.update_report_record(
- trade_time=trade_time,
- account_value=r.loc[trade_time]["account"],
- cash=r.loc[trade_time]["cash"],
- return_rate=r.loc[trade_time]["return"],
- turnover_rate=r.loc[trade_time]["turnover"],
- cost_rate=r.loc[trade_time]["cost"],
- stock_value=r.loc[trade_time]["value"],
- bench_value=r.loc[trade_time]["bench"],
+ trade_start_time=trade_start_time,
+ account_value=r.loc[trade_start_time]["account"],
+ cash=r.loc[trade_start_time]["cash"],
+ return_rate=r.loc[trade_start_time]["return"],
+ turnover_rate=r.loc[trade_start_time]["turnover"],
+ cost_rate=r.loc[trade_start_time]["cost"],
+ stock_value=r.loc[trade_start_time]["value"],
+ bench_value=r.loc[trade_start_time]["bench"],
)
+
+
+class Indicator:
+ def __init__(self):
+ self.indicator_his = dict()
+ self.trade_indicator = dict()
+
+ def __getitem__(self, key):
+ return self.trade_indicator[key]
+
+ def __setitem__(self, key, value):
+ self.trade_indicator[key] = value
+
+ def __contains__(self, key):
+ return key in self.trade_indicator
+
+ def clear(self):
+ self.trade_indicator = dict()
+
+ def record(self, trade_start_time):
+ self.indicator_his[trade_start_time] = pd.DataFrame(self.trade_indicator)
+
+ def update_trade_info(self, trade_info: list):
+ amount = dict()
+ deal_amount = dict()
+ trade_price = dict()
+ trade_cost = dict()
+
+ for order, _trade_val, _trade_cost, _trade_price in trade_info:
+ amount[order.stock_id] = order.amount * (order.direction * 2 - 1)
+ deal_amount[order.stock_id] = order.deal_amount * (order.direction * 2 - 1)
+ trade_price[order.stock_id] = _trade_price
+ trade_cost[order.stock_id] = _trade_cost
+
+ self["amount"] = pd.Series(amount)
+ self["deal_amount"] = pd.Series(deal_amount)
+ self["trade_price"] = pd.Series(trade_price)
+ self["trade_cost"] = pd.Series(trade_cost)
+
+ def update_FFR(self):
+ self["fulfill_rate"] = self["deal_amount"] / self["amount"]
+
+ def update_PA(self, freq, trade_start_time, trade_end_time, base_price="twap"):
+ base_price = base_price.lower()
+
+ instruments = list(self["amount"].index)
+ if base_price == "twap":
+ # too slow
+ # price_info, _ = get_higher_freq_feature(instruments, fields=["$close"], start_time=trade_start_time, end_time=trade_end_time, freq=freq)
+ # price_info = price_info.astype(float)
+
+ # self["base_price"] = price_info["$close"].groupby(level="instrument").mean()
+ self["base_price"] = self["trade_price"]
+
+ elif base_price == "vwap":
+ # too slow
+ price_info, _ = get_higher_freq_feature(
+ instruments,
+ fields=["$close", "$volume"],
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ freq=freq,
+ )
+ price_info = price_info.astype(float)
+ self["base_price"] = price_info.groupby(level="instrument").apply(
+ lambda x: (x["$close"] * x["$volume"]).sum() / x["$volume"].sum()
+ )
+ self["volume"] = price_info["$volume"].groupby(level="instrument").sum()
+ else:
+ raise ValueError(f"base_price {base_price} is not supported!")
+
+ self["pa"] = (self["trade_price"] - self["base_price"]) / self["base_price"]
+
+ def agg_report_info(self, inner_indicators):
+ amount = pd.Series()
+ deal_amount = pd.Series()
+ trade_price = pd.Series()
+ trade_cost = pd.Series()
+ for inner_indicator in inner_indicators:
+ amount = amount.add(inner_indicator["amount"], fill_value=0)
+ deal_amount = deal_amount.add(inner_indicator["deal_amount"], fill_value=0)
+ trade_price = trade_price.add(inner_indicator["trade_price"] * inner_indicator["deal_amount"], fill_value=0)
+ trade_cost = trade_cost.add(inner_indicator["trade_cost"], fill_value=0)
+
+ self["amount"] = amount
+ self["deal_amount"] = deal_amount
+ trade_price /= self["deal_amount"]
+ self["trade_price"] = trade_price
+ self["trade_cost"] = trade_cost
+
+ def agg_FFR(self):
+ self["fulfill_rate"] = self["deal_amount"] / self["amount"]
+
+ def agg_PA(self, inner_indicators, base_price="twap"):
+ base_price = base_price.lower()
+
+ if base_price == "twap":
+ base_price = pd.Series()
+ price_count = pd.Series()
+ for inner_indicator in inner_indicators:
+ base_price = base_price.add(inner_indicator["base_price"], fill_value=0)
+ price_count = price_count.add(pd.Series(1, index=inner_indicator["base_price"].index), fill_value=0)
+ base_price /= price_count
+ self["base_price"] = base_price
+
+ elif base_price == "vwap":
+ base_price = pd.Series()
+ volume = pd.Series()
+ for inner_indicator in inner_indicators:
+ base_price = base_price.add(inner_indicator["base_price"] * inner_indicator["volume"], fill_value=0)
+ volume = volume.add(inner_indicator["volume"], fill_value=0)
+ base_price /= volume
+ self["base_price"] = base_price
+ self["volume"] = volume
+ else:
+ raise ValueError(f"base_price {base_price} is not supported!")
+
+ self["pa"] = (self["trade_price"] - self["base_price"]) / self["base_price"]
+
+ def get_statistics_FFR(self, method="mean"):
+ if method == "mean":
+ return self["fulfill_rate"].mean()
+ elif method == "amount_weighted":
+ weights = self["deal_amount"].abs()
+ return (self["fulfill_rate"] * weights).sum() / weights.sum()
+ elif method == "value_weighted":
+ weights = (self["deal_amount"] * self["trade_price"]).abs()
+ return (self["fulfill_rate"] * weights).sum() / weights.sum()
+ else:
+ raise ValueError(f"method {method} is not supported!")
+
+ def get_statistics_PA(self, method="mean"):
+ pa_order = self["pa"] * (self["amount"] < 0).astype(int)
+
+ if method == "mean":
+ return pa_order.mean()
+ elif method == "amount_weighted":
+ weights = self["deal_amount"].abs()
+ return (pa_order * weights).sum() / weights.sum()
+ elif method == "value_weighted":
+ weights = (self["deal_amount"] * self["trade_price"]).abs()
+ return (pa_order * weights).sum() / weights.sum()
+ else:
+ raise ValueError(f"method {method} is not supported!")
+
+ def get_statistics_POS(self):
+ pa_order = self["pa"] * (self["amount"] < 0).astype(int)
+ return (pa_order > 1e-8).astype(int).sum() / len(pa_order)
diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py
index 8582cfe28..25ddc45a4 100644
--- a/qlib/backtest/utils.py
+++ b/qlib/backtest/utils.py
@@ -74,7 +74,12 @@ class TradeCalendarManager:
def get_step_time(self, trade_step=0, shift=0):
"""
- Get the time range of trading step
+ Get the left and right endpoints of the trade_step'th trading interval
+
+ About the endpoints:
+ - Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
+ - The returned right endpoints should minus 1 seconds becasue of the closed interval representation in Qlib.
+ Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
Parameters
----------
diff --git a/qlib/contrib/model/pytorch_tcts.py b/qlib/contrib/model/pytorch_tcts.py
new file mode 100644
index 000000000..9f44ba31c
--- /dev/null
+++ b/qlib/contrib/model/pytorch_tcts.py
@@ -0,0 +1,393 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+
+from __future__ import division
+from __future__ import print_function
+
+import os
+import numpy as np
+import pandas as pd
+import copy
+from sklearn.metrics import roc_auc_score, mean_squared_error
+import logging
+from ...utils import (
+ unpack_archive_with_buffer,
+ save_multiple_parts_file,
+ create_save_path,
+ drop_nan_by_y_index,
+)
+from ...log import get_module_logger, TimeInspector
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ...model.base import Model
+from ...data.dataset import DatasetH
+from ...data.dataset.handler import DataHandlerLP
+
+
+class TCTS(Model):
+ """TCTS Model
+
+ Parameters
+ ----------
+ d_feat : int
+ input dimension for each time step
+ metric: str
+ the evaluate metric used in early stop
+ optimizer : str
+ optimizer name
+ GPU : str
+ the GPU ID(s) used for training
+ """
+
+ def __init__(
+ self,
+ d_feat=6,
+ hidden_size=64,
+ num_layers=2,
+ dropout=0.0,
+ n_epochs=200,
+ batch_size=2000,
+ early_stop=20,
+ loss="mse",
+ fore_optimizer="adam",
+ weight_optimizer="adam",
+ output_dim=5,
+ fore_lr=5e-7,
+ weight_lr=5e-7,
+ steps=3,
+ GPU=0,
+ seed=None,
+ target_label=0,
+ **kwargs
+ ):
+ # Set logger.
+ self.logger = get_module_logger("TCTS")
+ self.logger.info("TCTS pytorch version...")
+
+ # set hyper-parameters.
+ self.d_feat = d_feat
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+ self.dropout = dropout
+ self.n_epochs = n_epochs
+ self.batch_size = batch_size
+ self.early_stop = early_stop
+ self.loss = loss
+ self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
+ self.use_gpu = torch.cuda.is_available()
+ self.seed = seed
+ self.output_dim = output_dim
+ self.fore_lr = fore_lr
+ self.weight_lr = weight_lr
+ self.steps = steps
+ self.target_label = target_label
+
+ self.logger.info(
+ "TCTS parameters setting:"
+ "\nd_feat : {}"
+ "\nhidden_size : {}"
+ "\nnum_layers : {}"
+ "\ndropout : {}"
+ "\nn_epochs : {}"
+ "\nbatch_size : {}"
+ "\nearly_stop : {}"
+ "\nloss_type : {}"
+ "\nvisible_GPU : {}"
+ "\nuse_GPU : {}"
+ "\nseed : {}".format(
+ d_feat,
+ hidden_size,
+ num_layers,
+ dropout,
+ n_epochs,
+ batch_size,
+ early_stop,
+ loss,
+ GPU,
+ self.use_gpu,
+ seed,
+ )
+ )
+
+ if self.seed is not None:
+ np.random.seed(self.seed)
+ torch.manual_seed(self.seed)
+
+ self.fore_model = GRUModel(
+ d_feat=self.d_feat,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ )
+ self.weight_model = MLPModel(
+ d_feat=360 + 2 * self.output_dim + 1,
+ hidden_size=self.hidden_size,
+ num_layers=self.num_layers,
+ dropout=self.dropout,
+ output_dim=self.output_dim,
+ )
+ if fore_optimizer.lower() == "adam":
+ self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
+ elif fore_optimizer.lower() == "gd":
+ self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(fore_optimizer))
+ if weight_optimizer.lower() == "adam":
+ self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
+ elif weight_optimizer.lower() == "gd":
+ self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
+ else:
+ raise NotImplementedError("optimizer {} is not supported!".format(weight_optimizer))
+
+ self.fitted = False
+ self.fore_model.to(self.device)
+ self.weight_model.to(self.device)
+
+ def loss_fn(self, pred, label, weight):
+
+ loc = torch.argmax(weight, 1)
+ loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
+ return torch.mean(loss)
+
+ def train_epoch(self, x_train, y_train, x_valid, y_valid):
+
+ x_train_values = x_train.values
+ y_train_values = np.squeeze(y_train.values)
+
+ indices = np.arange(len(x_train_values))
+ np.random.shuffle(indices)
+
+ init_fore_model = copy.deepcopy(self.fore_model)
+ for p in init_fore_model.parameters():
+ p.init_fore_model = False
+
+ self.fore_model.train()
+ self.weight_model.train()
+
+ for p in self.weight_model.parameters():
+ p.requires_grad = False
+ for p in self.fore_model.parameters():
+ p.requires_grad = True
+
+ for i in range(self.steps):
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ init_pred = init_fore_model(feature)
+ pred = self.fore_model(feature)
+
+ dis = init_pred - label.transpose(0, 1)
+ weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
+ weight = self.weight_model(weight_feature)
+
+ loss = self.loss_fn(pred, label, weight) # hard
+
+ self.fore_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
+ self.fore_optimizer.step()
+
+ x_valid_values = x_valid.values
+ y_valid_values = np.squeeze(y_valid.values)
+
+ indices = np.arange(len(x_valid_values))
+ np.random.shuffle(indices)
+ for p in self.weight_model.parameters():
+ p.requires_grad = True
+ for p in self.fore_model.parameters():
+ p.requires_grad = False
+
+ # fix forecasting model and valid weight model
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.fore_model(feature)
+ dis = pred - label.transpose(0, 1)
+ weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
+ weight = self.weight_model(weight_feature)
+ loc = torch.argmax(weight, 1)
+ valid_loss = torch.mean((pred - label[:, 0]) ** 2)
+ loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
+
+ self.weight_optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
+ self.weight_optimizer.step()
+
+ def test_epoch(self, data_x, data_y):
+
+ # prepare training data
+ x_values = data_x.values
+ y_values = np.squeeze(data_y.values)
+
+ self.fore_model.eval()
+
+ scores = []
+ losses = []
+
+ indices = np.arange(len(x_values))
+
+ for i in range(len(indices))[:: self.batch_size]:
+
+ if len(indices) - i < self.batch_size:
+ break
+
+ feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
+ label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
+
+ pred = self.fore_model(feature)
+ loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
+ losses.append(loss.item())
+
+ return np.mean(losses)
+
+ def fit(
+ self,
+ dataset: DatasetH,
+ evals_result=dict(),
+ verbose=True,
+ save_path=None,
+ ):
+
+ df_train, df_valid, df_test = dataset.prepare(
+ ["train", "valid", "test"],
+ col_set=["feature", "label"],
+ data_key=DataHandlerLP.DK_L,
+ )
+
+ x_train, y_train = df_train["feature"], df_train["label"]
+ x_valid, y_valid = df_valid["feature"], df_valid["label"]
+ x_test, y_test = df_test["feature"], df_test["label"]
+
+ if save_path == None:
+ save_path = create_save_path(save_path)
+
+ best_loss = np.inf
+ best_epoch = 0
+ stop_round = 0
+ fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
+ weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
+
+ for epoch in range(self.n_epochs):
+ print("Epoch:", epoch)
+
+ print("training...")
+ self.train_epoch(x_train, y_train, x_valid, y_valid)
+ print("evaluating...")
+ val_loss = self.test_epoch(x_valid, y_valid)
+ test_loss = self.test_epoch(x_test, y_test)
+
+ print("valid %.6f, test %.6f" % (val_loss, test_loss))
+
+ if val_loss < best_loss:
+ best_loss = val_loss
+ stop_round = 0
+ best_epoch = epoch
+ torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
+ torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
+
+ else:
+ stop_round += 1
+ if stop_round >= self.early_stop:
+ print("early stop")
+ break
+
+ print("best loss:", best_loss, "@", best_epoch)
+ best_param = torch.load(save_path + "_fore_model.bin")
+ self.fore_model.load_state_dict(best_param)
+ best_param = torch.load(save_path + "_weight_model.bin")
+ self.weight_model.load_state_dict(best_param)
+ self.fitted = True
+
+ if self.use_gpu:
+ torch.cuda.empty_cache()
+
+ def predict(self, dataset):
+ if not self.fitted:
+ raise ValueError("model is not fitted yet!")
+
+ x_test = dataset.prepare("test", col_set="feature")
+ index = x_test.index
+ self.fore_model.eval()
+ x_values = x_test.values
+ sample_num = x_values.shape[0]
+ preds = []
+
+ for begin in range(sample_num)[:: self.batch_size]:
+
+ if sample_num - begin < self.batch_size:
+ end = sample_num
+ else:
+ end = begin + self.batch_size
+
+ x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
+
+ with torch.no_grad():
+ if self.use_gpu:
+ pred = self.fore_model(x_batch).detach().cpu().numpy()
+ else:
+ pred = self.fore_model(x_batch).detach().numpy()
+
+ preds.append(pred)
+
+ return pd.Series(np.concatenate(preds), index=index)
+
+
+class MLPModel(nn.Module):
+ def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
+ super().__init__()
+
+ self.mlp = nn.Sequential()
+ self.softmax = nn.Softmax(dim=1)
+
+ for i in range(num_layers):
+ if i > 0:
+ self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
+ self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
+ self.mlp.add_module("relu_%d" % i, nn.ReLU())
+
+ self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
+
+ def forward(self, x):
+ # feature
+ # [N, F]
+ out = self.mlp(x).squeeze()
+ out = self.softmax(out)
+ return out
+
+
+class GRUModel(nn.Module):
+ def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
+ super().__init__()
+
+ self.rnn = nn.GRU(
+ input_size=d_feat,
+ hidden_size=hidden_size,
+ num_layers=num_layers,
+ batch_first=True,
+ dropout=dropout,
+ )
+ self.fc_out = nn.Linear(hidden_size, 1)
+
+ self.d_feat = d_feat
+
+ def forward(self, x):
+ # x: [N, F*T]
+ x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
+ x = x.permute(0, 2, 1) # [N, T, F]
+ out, _ = self.rnn(x)
+ return self.fc_out(out[:, -1, :]).squeeze()
diff --git a/qlib/contrib/model/xgboost.py b/qlib/contrib/model/xgboost.py
index 2a38f4fe1..300326143 100755
--- a/qlib/contrib/model/xgboost.py
+++ b/qlib/contrib/model/xgboost.py
@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
if self.model is None:
raise ValueError("model is not fitted yet!")
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
- return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
+ return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
"""get feature importance
diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py
index ba1e3c785..d88dcd7d6 100644
--- a/qlib/contrib/strategy/model_strategy.py
+++ b/qlib/contrib/strategy/model_strategy.py
@@ -51,6 +51,11 @@ class TopkDropoutStrategy(ModelStrategy):
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allowes different trade_exchanges is used in different executions.
+ - For example:
+ - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
+ - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
+
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
@@ -253,6 +258,15 @@ class WeightStrategyBase(ModelStrategy):
common_infra=None,
**kwargs,
):
+ """
+ trade_exchange : Exchange
+ exchange that provides market info, used to deal order and generate report
+ - If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allowes different trade_exchanges is used in different executions.
+ - For example:
+ - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
+ - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
+ """
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
)
@@ -301,18 +315,6 @@ class WeightStrategyBase(ModelStrategy):
raise NotImplementedError()
def generate_trade_decision(self, execute_result=None):
- """
- Parameters
- -----------
- score_series : pd.Seires
- stock_id , score.
- current : Position()
- current of account.
- trade_exchange : Exchange()
- exchange.
- trade_date : pd.Timestamp
- date.
- """
# generate_trade_decision
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py
index b72f32c29..300c983a0 100644
--- a/qlib/contrib/strategy/rule_strategy.py
+++ b/qlib/contrib/strategy/rule_strategy.py
@@ -1,4 +1,6 @@
import warnings
+import numpy as np
+import pandas as pd
from typing import List, Union
from ...utils.resam import resam_ts_data
@@ -28,6 +30,10 @@ class TWAPStrategy(BaseStrategy):
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allowes different trade_exchanges is used in different executions.
+ - For example:
+ - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
+ - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(TWAPStrategy, self).__init__(
@@ -88,27 +94,29 @@ class TWAPStrategy(BaseStrategy):
# considering trade unit
if _amount_trade_unit is None:
# divide the order into equal parts, and trade one part
- _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
+ _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
# without considering trade unit
- elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
+ else:
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# calculate the amount of one part, ceil the amount
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
_order_amount = (
- (trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
+ (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
)
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
- _order_amount is None or trade_step == trade_len
+ _order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
- if _order_amount:
- _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+ _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+
+ if _order_amount > 1e-5:
+
_order = Order(
stock_id=order.stock_id,
amount=_order_amount,
@@ -145,6 +153,10 @@ class SBBStrategyBase(BaseStrategy):
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
+ - It allowes different trade_exchanges is used in different executions.
+ - For example:
+ - In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
+ - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(SBBStrategyBase, self).__init__(
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
@@ -222,7 +234,7 @@ class SBBStrategyBase(BaseStrategy):
# divide the order into equal parts, and trade one part
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
# without considering trade unit
- elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
+ else:
# divide the order into equal parts, and trade one part
# calculate the total count of trade units to trade
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
@@ -234,11 +246,13 @@ class SBBStrategyBase(BaseStrategy):
if order.direction == order.SELL:
# sell all amount at last
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
- _order_amount is None or trade_step == trade_len - 1
+ _order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
- if _order_amount:
+ _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+
+ if _order_amount > 1e-5:
_order = Order(
stock_id=order.stock_id,
amount=_order_amount,
@@ -258,7 +272,7 @@ class SBBStrategyBase(BaseStrategy):
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
)
# without considering trade unit
- elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
+ else:
# cal how many trade unit
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
# N trade day left, divide the order into N + 1 parts, and trade 2 parts
@@ -270,13 +284,14 @@ class SBBStrategyBase(BaseStrategy):
)
if order.direction == order.SELL:
# sell all amount at last
- if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
- _order_amount is None or trade_step == trade_len - 1
+ if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
+ _order_amount < 1e-5 or trade_step == trade_len - 1
):
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
- if _order_amount:
- _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+ _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+
+ if _order_amount > 1e-5:
if trade_step % 2 == 0:
# in the first one of two adjacent bars
# if look short on the price, sell the stock more
@@ -402,3 +417,176 @@ class SBBStrategyEMA(SBBStrategyBase):
# if EMA signal > 0, return short trend
else:
return self.TREND_SHORT
+
+
+class VAStrategy(BaseStrategy):
+ def __init__(
+ self,
+ lamb: float = 1e-6,
+ eta: float = 2.5e-6,
+ window_size: int = 20,
+ outer_trade_decision: List[Order] = None,
+ instruments: Union[List, str] = "csi300",
+ freq: str = "day",
+ trade_exchange: Exchange = None,
+ level_infra: LevelInfrastructure = None,
+ common_infra: CommonInfrastructure = None,
+ **kwargs,
+ ):
+ """
+ Parameters
+ ----------
+ instruments : Union[List, str], optional
+ instruments of Volatility, by default "csi300"
+ freq : str, optional
+ freq of Volatility, by default "day"
+ Note: `freq` may be different from `time_per_step`
+ """
+ self.lamb = lamb
+ self.eta = eta
+ self.window_size = window_size
+ if instruments is None:
+ warnings.warn("`instruments` is not set, will load all stocks")
+ self.instruments = "all"
+ if isinstance(instruments, str):
+ self.instruments = D.instruments(instruments)
+ self.freq = freq
+ super(VAStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
+
+ if trade_exchange is not None:
+ self.trade_exchange = trade_exchange
+
+ def _reset_signal(self):
+ trade_len = self.trade_calendar.get_trade_len()
+ fields = [
+ f"Power(Sum(Power(Log($close/Ref($close, 1)), 2), {self.window_size})/{self.window_size - 1}-Power(Sum(Log($close/Ref($close, 1)), {self.window_size}), 2)/({self.window_size}*{self.window_size - 1}), 0.5)"
+ ]
+ signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
+ _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
+ signal_df = D.features(
+ self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
+ )
+ signal_df = convert_index_format(signal_df)
+ signal_df.columns = ["volatility"]
+ self.signal = {}
+
+ if not signal_df.empty:
+ for stock_id, stock_val in signal_df.groupby(level="instrument"):
+ self.signal[stock_id] = stock_val
+
+ def reset_common_infra(self, common_infra):
+ """
+ Parameters
+ ----------
+ common_infra : CommonInfrastructure, optional
+ common infrastructure for backtesting, by default None
+ - It should include `trade_account`, used to get position
+ - It should include `trade_exchange`, used to provide market info
+ """
+ super(VAStrategy, self).reset_common_infra(common_infra)
+
+ if common_infra.has("trade_exchange"):
+ self.trade_exchange = common_infra.get("trade_exchange")
+
+ def reset_level_infra(self, level_infra):
+ """
+ reset level-shared infra
+ - After reset the trade calendar, the signal will be changed
+ """
+ if not hasattr(self, "level_infra"):
+ self.level_infra = level_infra
+ else:
+ self.level_infra.update(level_infra)
+
+ if level_infra.has("trade_calendar"):
+ self.trade_calendar = level_infra.get("trade_calendar")
+ self._reset_signal()
+
+ def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
+ """
+ Parameters
+ ----------
+ outer_trade_decision : List[Order], optional
+ """
+ super(VAStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
+ if outer_trade_decision is not None:
+ self.trade_amount = {}
+ # init the trade amount of order and predicted trade trend
+ for order in outer_trade_decision:
+ self.trade_amount[(order.stock_id, order.direction)] = order.amount
+
+ def generate_trade_decision(self, execute_result=None):
+
+ # update the order amount
+ if execute_result is not None:
+ for order, _, _, _ in execute_result:
+ self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
+
+ # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
+ trade_step = self.trade_calendar.get_trade_step()
+ # get the total count of trading step
+ trade_len = self.trade_calendar.get_trade_len()
+ trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
+ pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
+ order_list = []
+ for order in self.outer_trade_decision:
+ # if not tradable, continue
+ if not self.trade_exchange.is_stock_tradable(
+ stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
+ ):
+ continue
+ _order_amount = None
+ # considering trade unit
+
+ sig_sam = (
+ resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last")
+ if order.stock_id in self.signal
+ else None
+ )
+
+ if sig_sam is None or sig_sam.iloc[0] is None:
+ # no signal, TWAP
+ _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
+ if _amount_trade_unit is None:
+ # divide the order into equal parts, and trade one part
+ _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
+ else:
+ # divide the order into equal parts, and trade one part
+ # calculate the total count of trade units to trade
+ trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
+ # calculate the amount of one part, ceil the amount
+ # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
+ _order_amount = (
+ (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
+ )
+ else:
+ # VA strategy
+ kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0]
+ kappa = np.arccosh(kappa_tild / 2 + 1)
+ amount_ratio = (
+ np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
+ ) / np.sinh(kappa * trade_len)
+ _order_amount = order.amount * amount_ratio
+ _order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, order.factor)
+
+ if order.direction == order.SELL:
+ # sell all amount at last
+ if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
+ _order_amount < 1e-5 or trade_step == trade_len - 1
+ ):
+ _order_amount = self.trade_amount[(order.stock_id, order.direction)]
+
+ _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
+
+ if _order_amount > 1e-5:
+
+ _order = Order(
+ stock_id=order.stock_id,
+ amount=_order_amount,
+ start_time=trade_start_time,
+ end_time=trade_end_time,
+ direction=order.direction, # 1 for buy
+ factor=order.factor,
+ )
+ order_list.append(_order)
+ return order_list
diff --git a/qlib/data/data.py b/qlib/data/data.py
index b2d8b075e..978fe6186 100644
--- a/qlib/data/data.py
+++ b/qlib/data/data.py
@@ -65,7 +65,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
def __init__(self, *args, **kwargs):
self.backend = kwargs.get("backend", {})
- @abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
"""Get calendar of certain market in given time range.
@@ -87,7 +86,22 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
list
calendar list
"""
- raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
+ _calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
+ # strip
+ if start_time:
+ start_time = pd.Timestamp(start_time)
+ if start_time > _calendar[-1]:
+ return np.array([])
+ else:
+ start_time = _calendar[0]
+ if end_time:
+ end_time = pd.Timestamp(end_time)
+ if end_time < _calendar[0]:
+ return np.array([])
+ else:
+ end_time = _calendar[-1]
+ st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
+ return _calendar[si : ei + 1]
def locate_index(self, start_time, end_time, freq, freq_sam=None, future=False):
"""Locate the start time index and end time index in a calendar under certain frequency.
@@ -172,6 +186,21 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
"""Get the uri of calendar generation task."""
return hash_args(start_time, end_time, freq, future)
+ def load_calendar(self, freq, future):
+ """Load original calendar timestamp from file.
+
+ Parameters
+ ----------
+ freq : str
+ frequency of read calendar file.
+
+ Returns
+ ----------
+ list
+ list of timestamps
+ """
+ raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method")
+
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
"""Instrument provider base class
@@ -457,7 +486,8 @@ class DatasetProvider(abc.ABC):
normalize_column_names = normalize_cache_fields(column_names)
data = dict()
# One process for one task, so that the memory will be freed quicker.
- workers = min(C.kernels, len(instruments_d))
+ workers = max(min(C.kernels, len(instruments_d)), 1)
+
if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
@@ -504,7 +534,9 @@ class DatasetProvider(abc.ABC):
data = pd.concat(new_data, names=["instrument"], sort=False)
data = DiskDatasetCache.cache_to_origin_data(data, column_names)
else:
- data = pd.DataFrame(columns=column_names)
+ data = pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
+ )
return data
@@ -558,19 +590,6 @@ class LocalCalendarProvider(CalendarProvider):
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
def load_calendar(self, freq, future):
- """Load original calendar timestamp from file.
-
- Parameters
- ----------
- freq : str
- frequency of read calendar file.
-
- Returns
- ----------
- list
- list of timestamps
- """
-
try:
backend_obj = self.backend_obj(freq=freq, future=future).data
except ValueError:
@@ -587,24 +606,6 @@ class LocalCalendarProvider(CalendarProvider):
return [pd.Timestamp(x) for x in backend_obj]
- def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
- _calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
- # strip
- if start_time:
- start_time = pd.Timestamp(start_time)
- if start_time > _calendar[-1]:
- return np.array([])
- else:
- start_time = _calendar[0]
- if end_time:
- end_time = pd.Timestamp(end_time)
- if end_time < _calendar[0]:
- return np.array([])
- else:
- end_time = _calendar[-1]
- st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
- return _calendar[si : ei + 1]
-
class LocalInstrumentProvider(InstrumentProvider):
"""Local instrument data provider class
@@ -719,7 +720,9 @@ class LocalDatasetProvider(DatasetProvider):
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
- return pd.DataFrame(columns=column_names)
+ return pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
+ )
start_time = cal[0]
end_time = cal[-1]
@@ -741,7 +744,7 @@ class LocalDatasetProvider(DatasetProvider):
return
start_time = cal[0]
end_time = cal[-1]
- workers = min(C.kernels, len(instruments_d))
+ workers = max(min(C.kernels, len(instruments_d)), 1)
if C.maxtasksperchild is None:
p = Pool(processes=workers)
else:
@@ -789,7 +792,7 @@ class ClientCalendarProvider(CalendarProvider):
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
self.conn.send_request(
- request_type="trade_calendar",
+ request_type="calendar",
request_content={
"start_time": str(start_time),
"end_time": str(end_time),
@@ -902,7 +905,10 @@ class ClientDatasetProvider(DatasetProvider):
column_names = self.get_column_names(fields)
cal = Cal.calendar(start_time, end_time, freq)
if len(cal) == 0:
- return pd.DataFrame(columns=column_names)
+ return pd.DataFrame(
+ index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")),
+ columns=column_names,
+ )
start_time = cal[0]
end_time = cal[-1]
@@ -1004,7 +1010,7 @@ class LocalProvider(BaseProvider):
:param type: The type of resource for the uri
:param **kwargs:
"""
- if type == "trade_calendar":
+ if type == "calendar":
return Cal._uri(**kwargs)
elif type == "instrument":
return Inst._uri(**kwargs)
diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py
index fd76e6728..28d854477 100644
--- a/qlib/model/trainer.py
+++ b/qlib/model/trainer.py
@@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
"""
import socket
+import time
from typing import Callable, List
from qlib.data.dataset import Dataset
+from qlib.log import get_module_logger
from qlib.model.base import Model
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
from qlib.workflow import R
@@ -190,6 +192,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
+ if isinstance(tasks, dict):
+ tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -213,6 +217,8 @@ class TrainerR(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
+ if isinstance(recs, Recorder):
+ recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
@@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR):
Returns:
List[Recorder]: a list of Recorders
"""
+ if isinstance(recs, Recorder):
+ recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -275,6 +283,9 @@ class TrainerRM(Trainer):
STATUS_BEGIN = "begin_task_train"
STATUS_END = "end_task_train"
+ # This tag is the _id in TaskManager to distinguish tasks.
+ TM_ID = "_id in TaskManager"
+
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
"""
Init TrainerR.
@@ -315,6 +326,8 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: a list of Recorders
"""
+ if isinstance(tasks, dict):
+ tasks = [tasks]
if len(tasks) == 0:
return []
if train_func is None:
@@ -326,19 +339,25 @@ class TrainerRM(Trainer):
task_pool = experiment_name
tm = TaskManager(task_pool=task_pool)
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
+ query = {"_id": {"$in": _id_list}}
run_task(
train_func,
task_pool,
+ query=query, # only train these tasks
experiment_name=experiment_name,
before_status=before_status,
after_status=after_status,
**kwargs,
)
+ if not self.is_delay():
+ tm.wait(query=query)
+
recs = []
for _id in _id_list:
rec = tm.re_query(_id)["res"]
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
+ rec.set_tags(**{self.TM_ID: _id})
recs.append(rec)
return recs
@@ -352,10 +371,33 @@ class TrainerRM(Trainer):
Returns:
List[Recorder]: the same list as the param.
"""
+ if isinstance(recs, Recorder):
+ recs = [recs]
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
+ def worker(
+ self,
+ train_func: Callable = None,
+ experiment_name: str = None,
+ ):
+ """
+ The multiprocessing method for `train`. It can share a same task_pool with `train` and can run in other progress or other machines.
+
+ Args:
+ train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
+ experiment_name (str): the experiment name, None for use default name.
+ """
+ if train_func is None:
+ train_func = self.train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ task_pool = self.task_pool
+ if task_pool is None:
+ task_pool = experiment_name
+ run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
+
class DelayTrainerRM(TrainerRM):
"""
@@ -395,6 +437,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
+ if isinstance(tasks, dict):
+ tasks = [tasks]
if len(tasks) == 0:
return []
return super().train(
@@ -410,8 +454,6 @@ class DelayTrainerRM(TrainerRM):
Given a list of Recorder and return a list of trained Recorder.
This class will finish real data loading and model fitting.
- NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
-
Args:
recs (list): a list of Recorder, the tasks have been saved to them.
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
@@ -421,7 +463,8 @@ class DelayTrainerRM(TrainerRM):
Returns:
List[Recorder]: a list of Recorders
"""
-
+ if isinstance(recs, Recorder):
+ recs = [recs]
if end_train_func is None:
end_train_func = self.end_train_func
if experiment_name is None:
@@ -429,18 +472,44 @@ class DelayTrainerRM(TrainerRM):
task_pool = self.task_pool
if task_pool is None:
task_pool = experiment_name
- tasks = []
+ _id_list = []
for rec in recs:
- tasks.append(rec.load_object("task"))
+ _id_list.append(rec.list_tags()[self.TM_ID])
+ query = {"_id": {"$in": _id_list}}
run_task(
end_train_func,
task_pool,
- query={"filter": {"$in": tasks}}, # only train these tasks
+ query=query, # only train these tasks
experiment_name=experiment_name,
before_status=TaskManager.STATUS_PART_DONE,
**kwargs,
)
+
+ TaskManager(task_pool=task_pool).wait(query=query)
+
for rec in recs:
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
return recs
+
+ def worker(self, end_train_func=None, experiment_name: str = None):
+ """
+ The multiprocessing method for `end_train`. It can share a same task_pool with `end_train` and can run in other progress or other machines.
+
+ Args:
+ end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
+ experiment_name (str): the experiment name, None for use default name.
+ """
+ if end_train_func is None:
+ end_train_func = self.end_train_func
+ if experiment_name is None:
+ experiment_name = self.experiment_name
+ task_pool = self.task_pool
+ if task_pool is None:
+ task_pool = experiment_name
+ run_task(
+ end_train_func,
+ task_pool=task_pool,
+ experiment_name=experiment_name,
+ before_status=TaskManager.STATUS_PART_DONE,
+ )
diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py
index 1e310e8ad..c711b8380 100644
--- a/qlib/rl/interpreter.py
+++ b/qlib/rl/interpreter.py
@@ -5,14 +5,14 @@
class BaseInterpreter:
"""Base Interpreter"""
- def interpret(**kwargs):
+ def interpret(self, **kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter(BaseInterpreter):
"""Action Interpreter that interpret rl agent action into qlib orders"""
- def interpret(action, **kwargs):
+ def interpret(self, action, **kwargs):
"""interpret method
Parameters
@@ -32,7 +32,7 @@ class ActionInterpreter(BaseInterpreter):
class StateInterpreter(BaseInterpreter):
"""State Interpreter that interpret execution result of qlib executor into rl env state"""
- def interpret(execute_result, **kwargs):
+ def interpret(self, execute_result, **kwargs):
"""interpret method
Parameters
diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py
index 9d3e0c72b..961fb5044 100644
--- a/qlib/strategy/base.py
+++ b/qlib/strategy/base.py
@@ -175,7 +175,7 @@ class RLIntStrategy(RLStrategy):
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
def generate_trade_decision(self, execute_result=None):
- _interpret_state = self.state_interpretor.interpret(execute_result=execute_result)
+ _interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
_action = self.policy.step(_interpret_state)
_trade_decision = self.action_interpreter.interpret(action=_action)
return _trade_decision
diff --git a/qlib/tests/config.py b/qlib/tests/config.py
index 80461f6f9..c61b5651e 100644
--- a/qlib/tests/config.py
+++ b/qlib/tests/config.py
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
]
-def get_data_handler_config(market=CSI300_MARKET):
+def get_data_handler_config(
+ start_time="2008-01-01",
+ end_time="2020-08-01",
+ fit_start_time="2008-01-01",
+ fit_end_time="2014-12-31",
+ instruments=CSI300_MARKET,
+):
return {
- "start_time": "2008-01-01",
- "end_time": "2020-08-01",
- "fit_start_time": "2008-01-01",
- "fit_end_time": "2014-12-31",
- "instruments": market,
+ "start_time": start_time,
+ "end_time": end_time,
+ "fit_start_time": fit_start_time,
+ "fit_end_time": fit_end_time,
+ "instruments": instruments,
}
-def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
+def get_dataset_config(
+ dataset_class=DATASET_ALPHA158_CLASS,
+ train=("2008-01-01", "2014-12-31"),
+ valid=("2015-01-01", "2016-12-31"),
+ test=("2017-01-01", "2020-08-01"),
+ handler_kwargs={"instruments": CSI300_MARKET},
+):
return {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
"handler": {
"class": dataset_class,
"module_path": "qlib.contrib.data.handler",
- "kwargs": get_data_handler_config(market),
+ "kwargs": get_data_handler_config(**handler_kwargs),
},
"segments": {
- "train": ("2008-01-01", "2014-12-31"),
- "valid": ("2015-01-01", "2016-12-31"),
- "test": ("2017-01-01", "2020-08-01"),
+ "train": train,
+ "valid": valid,
+ "test": test,
},
},
}
-def get_gbdt_task(market=CSI300_MARKET):
+def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": GBDT_MODEL,
- "dataset": get_dataset_config(market),
+ "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
}
-def get_record_lgb_config(market=CSI300_MARKET):
+def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
},
- "dataset": get_dataset_config(market),
+ "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
-def get_record_xgboost_config(market=CSI300_MARKET):
+def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
return {
"model": {
"class": "XGBModel",
"module_path": "qlib.contrib.model.xgboost",
},
- "dataset": get_dataset_config(market),
+ "dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
"record": RECORD_CONFIG,
}
-CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
-CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
+CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
+CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
-CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
-CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
+CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
+CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
+
+# use for rolling_online_managment.py
+ROLLING_HANDLER_CONFIG = {
+ "start_time": "2013-01-01",
+ "end_time": "2020-09-25",
+ "fit_start_time": "2013-01-01",
+ "fit_end_time": "2014-12-31",
+ "instruments": CSI100_MARKET,
+}
+ROLLING_DATASET_CONFIG = {
+ "train": ("2013-01-01", "2014-12-31"),
+ "valid": ("2015-01-01", "2015-12-31"),
+ "test": ("2016-01-01", "2020-07-10"),
+}
+CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
+ dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
+)
+CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
+ dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
+)
+
+# use for online_management_simulate.py
+ONLINE_HANDLER_CONFIG = {
+ "start_time": "2018-01-01",
+ "end_time": "2018-10-31",
+ "fit_start_time": "2018-01-01",
+ "fit_end_time": "2018-03-31",
+ "instruments": CSI100_MARKET,
+}
+ONLINE_DATASET_CONFIG = {
+ "train": ("2018-01-01", "2018-03-31"),
+ "valid": ("2018-04-01", "2018-05-31"),
+ "test": ("2018-06-01", "2018-09-10"),
+}
+CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
+ dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
+)
+CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
+ dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
+)
diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py
index 71e0aa654..d8198fc99 100644
--- a/qlib/utils/resam.py
+++ b/qlib/utils/resam.py
@@ -8,6 +8,11 @@ from typing import Tuple, List, Union, Optional, Callable
from . import lazy_sort_index
from ..config import C
+NORM_FREQ_MONTH = "month"
+NORM_FREQ_WEEK = "week"
+NORM_FREQ_DAY = "day"
+NORM_FREQ_MINUTE = "minute"
+
def parse_freq(freq: str) -> Tuple[int, str]:
"""
@@ -43,14 +48,14 @@ def parse_freq(freq: str) -> Tuple[int, str]:
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
_freq = match_obj.group(2)
_freq_format_dict = {
- "month": "month",
- "mon": "month",
- "week": "week",
- "w": "week",
- "day": "day",
- "d": "day",
- "minute": "minute",
- "min": "minute",
+ "month": NORM_FREQ_MONTH,
+ "mon": NORM_FREQ_MONTH,
+ "week": NORM_FREQ_WEEK,
+ "w": NORM_FREQ_WEEK,
+ "day": NORM_FREQ_DAY,
+ "d": NORM_FREQ_DAY,
+ "minute": NORM_FREQ_MINUTE,
+ "min": NORM_FREQ_MINUTE,
}
return _count, _freq_format_dict[_freq]
@@ -81,7 +86,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
return calendar_raw
# if freq_sam is xminute, divide each trading day into several bars evenly
- if freq_sam == "minute":
+ if freq_sam == NORM_FREQ_MINUTE:
def cal_sam_minute(x, sam_minutes):
"""
@@ -114,7 +119,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
else:
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
- if freq_raw != "minute":
+ if freq_raw != NORM_FREQ_MINUTE:
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
else:
if raw_count > sam_count:
@@ -125,15 +130,15 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
- if freq_sam == "day":
+ if freq_sam == NORM_FREQ_DAY:
return _calendar_day[::sam_count]
- elif freq_sam == "week":
+ elif freq_sam == NORM_FREQ_WEEK:
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
return _calendar_week[::sam_count]
- elif freq_sam == "month":
+ elif freq_sam == NORM_FREQ_MONTH:
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
return _calendar_month[::sam_count]
@@ -184,7 +189,7 @@ def get_resam_calendar(
freq, freq_sam = freq, None
except (ValueError, KeyError):
freq_sam = freq
- if norm_freq in ["month", "week", "day"]:
+ if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]:
try:
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
@@ -195,7 +200,7 @@ def get_resam_calendar(
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
freq = "1min"
- elif norm_freq == "minute":
+ elif norm_freq == NORM_FREQ_MINUTE:
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
)
@@ -205,6 +210,57 @@ def get_resam_calendar(
return _calendar, freq, freq_sam
+def get_higher_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
+ """[summary]
+
+ Parameters
+ ----------
+ instruments : [type]
+ [description]
+ fields : [type]
+ [description]
+ start_time : [type], optional
+ [description], by default None
+ end_time : [type], optional
+ [description], by default None
+ freq : str, optional
+ [description], by default "day"
+ disk_cache : int, optional
+ [description], by default 1
+
+ Returns
+ -------
+ [type]
+ [description]
+
+ Raises
+ ------
+ ValueError
+ [description]
+ """
+
+ from ..data.data import D
+
+ try:
+ _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
+ _freq = freq
+ except (ValueError, KeyError):
+ _, norm_freq = parse_freq(freq)
+ if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]:
+ try:
+ _result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache)
+ _freq = "day"
+ except (ValueError, KeyError):
+ _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
+ _freq = "1min"
+ elif norm_freq == NORM_FREQ_MINUTE:
+ _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
+ _freq = "1min"
+ else:
+ raise ValueError(f"freq {freq} is not supported")
+ return _result, _freq
+
+
def resam_ts_data(
ts_feature: Union[pd.DataFrame, pd.Series],
start_time: Union[str, pd.Timestamp] = None,
@@ -273,8 +329,9 @@ def resam_ts_data(
end sampling time, by default None
method : Union[str, Callable], optional
sample method, apply method function to each stock series data, by default "last"
- - If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
- - If `feature` has MultiIndex[instrument, datetime], method must be a member of pandas.groupby when it's type is str.or callable function.
+ - If type(method) is str or callable function, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and applies groupy.method for the sliced time-series data
+ - If method is None, do nothing for the sliced time-series data.
+ - Only when the index `feature` is MultiIndex[instrument, datetime], the method is valid.
method_kwargs : dict, optional
arguments of method, by default {}
diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py
index 443cd61ad..d3cc0cbf8 100644
--- a/qlib/workflow/online/manager.py
+++ b/qlib/workflow/online/manager.py
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
========================= ===================================================================================
Situations Description
========================= ===================================================================================
-Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
+Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
+ will train models task by task and strategy by strategy.
-Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
- in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
+Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
+ nothing until all tasks have been prepared. It makes user can train all tasks in
+ the end of `routine` or `first_train`.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
@@ -103,17 +105,21 @@ class OnlineManager(Serializable):
"""
if strategies is None:
strategies = self.strategies
- for strategy in strategies:
+ models_list = []
+ for strategy in strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
tasks = strategy.first_tasks()
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
- models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+ models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
-
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
+ if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
+ for strategy, models in zip(strategies, models_list):
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+
def routine(
self,
cur_time: Union[str, pd.Timestamp] = None,
@@ -139,33 +145,38 @@ class OnlineManager(Serializable):
cur_time = D.calendar(freq=self.freq).max()
self.cur_time = pd.Timestamp(cur_time) # None for latest date
+ models_list = []
for strategy in self.strategies:
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
if self.status == self.STATUS_NORMAL:
strategy.tool.update_online_pred()
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
- models = self.trainer.train(tasks)
- if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
- models = self.trainer.end_train(models, experiment_name=strategy.name_id)
+ models = self.trainer.train(tasks, experiment_name=strategy.name_id)
+ models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
- if not self.trainer.is_delay():
+ if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
+ for strategy, models in zip(self.strategies, models_list):
+ models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
- def get_collector(self) -> MergeCollector:
+ def get_collector(self, **kwargs) -> MergeCollector:
"""
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
This collector can be a basis as the signals preparation.
+ Args:
+ **kwargs: the params for get_collector.
+
Returns:
MergeCollector: the collector to merge other collectors.
"""
collector_dict = {}
for strategy in self.strategies:
- collector_dict[strategy.name_id] = strategy.get_collector()
+ collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
return MergeCollector(collector_dict, process_list=[])
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
@@ -297,6 +308,7 @@ class OnlineManager(Serializable):
# NOTE: Assumption: the predictions of online models need less than next cur_time, or this method will work in a wrong way.
self.prepare_signals(**signal_kwargs)
if signals_time > cur_time:
+ # FIXME: if use DelayTrainer and worker (and worker is faster than main progress), there are some possibilities of showing this warning.
self.logger.warn(
f"The signals have already parpred to {signals_time} by last preparation, but current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
)
diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py
index 8abcd6c14..9516d363a 100644
--- a/qlib/workflow/record_temp.py
+++ b/qlib/workflow/record_temp.py
@@ -17,7 +17,7 @@ from ..log import get_module_logger
from ..utils import flatten_dict
from ..utils.resam import parse_freq
from ..strategy.base import BaseStrategy
-from ..contrib.eva.alpha import calc_ic, calc_long_short_return
+from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
logger = get_module_logger("workflow", logging.INFO)
@@ -302,7 +302,7 @@ class PortAnaRecord(RecordTemp):
define the executor class as well as the kwargs.
config["backtest"] : dict
define the backtest kwargs.
- risk_analysis_freq : int
+ risk_analysis_freq : str|List[str]
risk analysis freq of report
"""
super().__init__(recorder=recorder, **kwargs)
@@ -310,8 +310,11 @@ class PortAnaRecord(RecordTemp):
self.strategy_config = config["strategy"]
self.executor_config = config["executor"]
self.backtest_config = config["backtest"]
- _count, _freq = parse_freq(risk_analysis_freq)
- self.risk_analysis_freq = f"{_count}{_freq}"
+ if isinstance(risk_analysis_freq, str):
+ risk_analysis_freq = [risk_analysis_freq]
+ self.risk_analysis_freq = [
+ "{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in risk_analysis_freq
+ ]
self.report_freq = self._get_report_freq(self.executor_config)
def _get_report_freq(self, executor_config):
@@ -336,34 +339,35 @@ class PortAnaRecord(RecordTemp):
**{f"positions_normal_{report_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
- if self.risk_analysis_freq not in report_dict:
- warnings.warn(
- f"the freq {self.risk_analysis_freq} report is not found, please set the corresponding env with `generate_report==True`"
- )
- else:
- report_normal, _ = report_dict.get(self.risk_analysis_freq)
- analysis = dict()
- analysis["excess_return_without_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"], freq=self.risk_analysis_freq
- )
- analysis["excess_return_with_cost"] = risk_analysis(
- report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=self.risk_analysis_freq
- )
- analysis_df = pd.concat(analysis) # type: pd.DataFrame
- # log metrics
- self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
- # save results
- self.recorder.save_objects(
- **{f"port_analysis_{report_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
- )
- logger.info(
- f"Portfolio analysis record 'port_analysis_{report_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
- )
- # print out results
- pprint("The following are analysis results of the excess return without cost.")
- pprint(analysis["excess_return_without_cost"])
- pprint("The following are analysis results of the excess return with cost.")
- pprint(analysis["excess_return_with_cost"])
+ for _analysis_freq in self.risk_analysis_freq:
+ if _analysis_freq not in report_dict:
+ warnings.warn(
+ f"the freq {_analysis_freq} report is not found, please set the corresponding env with `generate_report==True`"
+ )
+ else:
+ report_normal, _ = report_dict.get(_analysis_freq)
+ analysis = dict()
+ analysis["excess_return_without_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"], freq=_analysis_freq
+ )
+ analysis["excess_return_with_cost"] = risk_analysis(
+ report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=_analysis_freq
+ )
+ analysis_df = pd.concat(analysis) # type: pd.DataFrame
+ # log metrics
+ self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
+ # save results
+ self.recorder.save_objects(
+ **{f"port_analysis_{report_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
+ )
+ logger.info(
+ f"Portfolio analysis record 'port_analysis_{report_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
+ )
+ # print out results
+ pprint("The following are analysis results of the excess return without cost.")
+ pprint(analysis["excess_return_without_cost"])
+ pprint("The following are analysis results of the excess return with cost.")
+ pprint(analysis["excess_return_with_cost"])
def list(self):
list_path = []
@@ -374,6 +378,10 @@ class PortAnaRecord(RecordTemp):
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
]
)
- if _freq == self.risk_analysis_freq:
- list_path.append(PortAnaRecord.get_path(f"port_analysis_{_freq}.pkl"))
+
+ for _analysis_freq in self.risk_analysis_freq:
+ if _analysis_freq in self.report_freq:
+ list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
+ else:
+ warnings.warn(f"{_analysis_freq} is not found")
return list_path
diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py
index 658eec4d6..7a85036da 100644
--- a/qlib/workflow/task/manage.py
+++ b/qlib/workflow/task/manage.py
@@ -69,28 +69,29 @@ class TaskManager:
ENCODE_FIELDS_PREFIX = ["def", "res"]
- def __init__(self, task_pool: str = None):
+ def __init__(self, task_pool: str):
"""
Init Task Manager, remember to make the statement of MongoDB url and database name firstly.
+ A TaskManager instance serves a specific task pool.
+ The static method of this module serves the whole MongoDB.
Parameters
----------
task_pool: str
the name of Collection in MongoDB
"""
- self.mdb = get_mongodb()
- if task_pool is not None:
- self.task_pool = getattr(self.mdb, task_pool)
+ self.task_pool = getattr(get_mongodb(), task_pool)
self.logger = get_module_logger(self.__class__.__name__)
- def list(self) -> list:
+ @staticmethod
+ def list() -> list:
"""
- List the all collection(task_pool) of the db
+ List the all collection(task_pool) of the db.
Returns:
list
"""
- return self.mdb.list_collection_names()
+ return get_mongodb().list_collection_names()
def _encode_task(self, task):
for prefix in self.ENCODE_FIELDS_PREFIX:
@@ -109,6 +110,25 @@ class TaskManager:
def _dict_to_str(self, flt):
return {k: str(v) for k, v in flt.items()}
+ def _decode_query(self, query):
+ """
+ If the query includes any `_id`, then it needs `ObjectId` to decode.
+ For example, when using TrainerRM, it needs query `{"_id": {"$in": _id_list}}`. Then we need to `ObjectId` every `_id` in `_id_list`.
+
+ Args:
+ query (dict): query dict. Defaults to {}.
+
+ Returns:
+ dict: the query after decoding.
+ """
+ if "_id" in query:
+ if isinstance(query["_id"], dict):
+ for key in query["_id"]:
+ query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
+ else:
+ query["_id"] = ObjectId(query["_id"])
+ return query
+
def replace_task(self, task, new_task):
"""
Use a new task to replace a old one
@@ -224,8 +244,7 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
- if "_id" in query:
- query["_id"] = ObjectId(query["_id"])
+ query = self._decode_query(query)
query.update({"status": status})
task = self.task_pool.find_one_and_update(
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
@@ -283,12 +302,11 @@ class TaskManager:
dict: a task(document in collection) after decoding
"""
query = query.copy()
- if "_id" in query:
- query["_id"] = ObjectId(query["_id"])
+ query = self._decode_query(query)
for t in self.task_pool.find(query):
yield self._decode_task(t)
- def re_query(self, _id):
+ def re_query(self, _id) -> dict:
"""
Use _id to query task.
@@ -339,8 +357,7 @@ class TaskManager:
"""
query = query.copy()
- if "_id" in query:
- query["_id"] = ObjectId(query["_id"])
+ query = self._decode_query(query)
self.task_pool.delete_many(query)
def task_stat(self, query={}) -> dict:
@@ -354,8 +371,7 @@ class TaskManager:
dict
"""
query = query.copy()
- if "_id" in query:
- query["_id"] = ObjectId(query["_id"])
+ query = self._decode_query(query)
tasks = self.query(query=query, decode=False)
status_stat = {}
for t in tasks:
@@ -377,8 +393,7 @@ class TaskManager:
def reset_status(self, query, status):
query = query.copy()
- if "_id" in query:
- query["_id"] = ObjectId(query["_id"])
+ query = self._decode_query(query)
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
def prioritize(self, task, priority: int):
@@ -402,9 +417,19 @@ class TaskManager:
return sum(task_stat.values())
def wait(self, query={}):
+ """
+ When multiprocessing, the main progress may fetch nothing from TaskManager because there are still some running tasks.
+ So main progress should wait until all tasks are trained well by other progress or machines.
+
+ Args:
+ query (dict, optional): the query dict. Defaults to {}.
+ """
task_stat = self.task_stat(query)
total = self._get_total(task_stat)
last_undone_n = self._get_undone_n(task_stat)
+ if last_undone_n == 0:
+ return
+ self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
with tqdm(total=total, initial=total - last_undone_n) as pbar:
while True:
time.sleep(10)
diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py
index cd87187e9..5a93eacca 100644
--- a/qlib/workflow/utils.py
+++ b/qlib/workflow/utils.py
@@ -17,7 +17,6 @@ def experiment_exit_handler():
Thus, if any exception or user interuption occurs beforehead, we should handle them first. Once `R` is
ended, another call of `R.end_exp` will not take effect.
"""
- signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
sys.excepthook = experiment_exception_hook # handle uncaught exception
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
@@ -39,11 +38,3 @@ def experiment_exception_hook(type, value, tb):
print(f"{type.__name__}: {value}")
R.end_exp(recorder_status=Recorder.STATUS_FA)
-
-
-def experiment_kill_signal_handler(signum, frame):
- """
- End an experiment when user kill the program through keyboard (CTRL+C, etc.).
- """
- R.end_exp(recorder_status=Recorder.STATUS_FA)
- raise KeyboardInterrupt