mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
Merge branch 'nested_decision_exe' of https://github.com/microsoft/qlib into rl-dummy
52
examples/benchmarks/TCTS/TCTS.md
Normal file
@@ -0,0 +1,52 @@
|
||||
# Temporally Correlated Task Scheduling for Sequence Learning
|
||||
We provide the [code](https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tcts.py) for reproducing the stock trend forecasting experiments.
|
||||
|
||||
### Background
|
||||
Sequence learning has attracted much research attention from the machine learning community in recent years. In many applications, a sequence learning task is associated with multiple temporally correlated auxiliary tasks, which differ in how much input information they use or which future step they predict. In stock trend forecasting, as shown in Figure 1, one can predict the price of a stock on different future days (e.g., tomorrow, the day after tomorrow). In this paper, we propose a framework that lets those temporally correlated tasks help each other.
|
||||
|
||||
<p align="center">
|
||||
<img src="task_description.png" width="600" height="200"/>
|
||||
</p>
|
||||
|
||||
|
||||
### Method
|
||||
Given that there are usually multiple temporally correlated tasks, the key challenge lies in deciding which tasks to use and when to use them during training. In this work, we introduce a learnable task scheduler for sequence learning, which adaptively selects temporally correlated tasks during the training process. The scheduler accesses the model status and the current training data (e.g., the current minibatch), and selects the best auxiliary task to help the training of the main task. The scheduler and the model for the main task are jointly trained through bi-level optimization: the scheduler is trained to maximize the validation performance of the model, and the model is trained to minimize the training loss guided by the scheduler. The process is shown in Figure 2.
|
||||
|
||||
<p align="center">
|
||||
<img src="workflow.png"/>
|
||||
</p>
|
||||
|
||||
At step <img src="https://render.githubusercontent.com/render/math?math=s">, with training data <img src="https://render.githubusercontent.com/render/math?math=x_s,y_s">, the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> chooses a suitable task <img src="https://render.githubusercontent.com/render/math?math=T_{i_s}"> (green solid lines) to update the model <img src="https://render.githubusercontent.com/render/math?math=f"> (blue solid lines). After <img src="https://render.githubusercontent.com/render/math?math=S"> steps, we evaluate the model <img src="https://render.githubusercontent.com/render/math?math=f"> on the validation set and update the scheduler <img src="https://render.githubusercontent.com/render/math?math=\varphi"> (green dashed lines).
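
The alternation described above can be sketched in plain Python. This is a toy illustration under strong assumptions, not the code in `pytorch_tcts.py`: the model is a single weight `w`, each task is a hypothetical linear target with its own slope, the scheduler is a context-free softmax over task logits (the real scheduler also looks at the model status and the minibatch), and the scheduler update uses a REINFORCE-style estimate as a stand-in for the bi-level gradient. The name `run_toy_tcts` and all hyperparameters are invented for this sketch.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mse(w, xs, ys):
    """Mean squared error of the one-parameter model f(x) = w * x."""
    return sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


def run_toy_tcts(slopes, main_task=0, steps=3, rounds=50,
                 model_lr=0.05, sched_lr=0.5, seed=0):
    """Alternate S model steps with one scheduler update (toy version)."""
    rng = random.Random(seed)
    valid_x = [rng.uniform(-1, 1) for _ in range(64)]
    valid_y = [slopes[main_task] * x for x in valid_x]

    w = 0.0                        # model f (assumed: one weight)
    logits = [0.0] * len(slopes)   # scheduler phi (assumed: context-free)
    baseline = None                # running mean of the reward
    history = [mse(w, valid_x, valid_y)]

    for _ in range(rounds):
        probs = softmax(logits)
        counts = [0] * len(slopes)
        for _ in range(steps):
            # The scheduler picks task T_{i_s} for this minibatch.
            task = rng.choices(range(len(slopes)), weights=probs)[0]
            counts[task] += 1
            xs = [rng.uniform(-1, 1) for _ in range(32)]
            ys = [slopes[task] * x for x in xs]
            # One SGD step on the chosen task's training loss.
            grad = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)
            w -= model_lr * grad

        # After S steps, the main-task validation loss drives the scheduler.
        valid_loss = mse(w, valid_x, valid_y)
        reward = -valid_loss
        baseline = reward if baseline is None else 0.9 * baseline + 0.1 * reward
        advantage = reward - baseline
        for j in range(len(slopes)):
            # d/d logit_j of the summed log-probabilities of the sampled tasks.
            logits[j] += sched_lr * advantage * (counts[j] - steps * probs[j])
        history.append(valid_loss)

    return w, softmax(logits), history


w, probs, history = run_toy_tcts([2.0, 1.8, 1.5])
```

Here `steps` plays the role of S, `history` records the main-task validation loss after each scheduler update, and `probs` is the scheduler's final distribution over tasks.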
|
||||
|
||||
### Dataset
|
||||
* We use the historical transaction data for the 300 stocks in the [CSI300](http://www.csindex.com.cn/en/indices/index-detail/000300) index from 01/01/2008 to 08/01/2020.
|
||||
* We split the data into training (01/01/2008-12/31/2013), validation (01/01/2014-12/31/2015), and test sets (01/01/2016-08/01/2020) based on the transaction time.
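
As a small illustration of the time-based split, the sketch below maps a trading day to its segment. It assumes the dates above are written as month/day/year and that both ends of each range are inclusive; `SEGMENTS` and `segment_of` are hypothetical names used only here.

```python
from datetime import date

# Segment boundaries copied from the list above (assumed inclusive on both ends).
SEGMENTS = {
    "train": (date(2008, 1, 1), date(2013, 12, 31)),
    "valid": (date(2014, 1, 1), date(2015, 12, 31)),
    "test": (date(2016, 1, 1), date(2020, 8, 1)),
}


def segment_of(day):
    """Return the segment name for a trading day, or None if it is out of range."""
    for name, (start, end) in SEGMENTS.items():
        if start <= day <= end:
            return name
    return None
```

For example, `segment_of(date(2014, 6, 30))` falls in the validation segment, and any day after 08/01/2020 maps to `None`.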
|
||||
|
||||
### Experiments
|
||||
#### Task Description
|
||||
* The main task <img src="https://render.githubusercontent.com/render/math?math=T_k"> (<img src="https://render.githubusercontent.com/render/math?math=task_k"> in Figure 1) is to forecast the return of stock <img src="https://render.githubusercontent.com/render/math?math=i">, defined as follows:
|
||||
<div align=center>
|
||||
<img src="https://render.githubusercontent.com/render/math?math=r_{i}^k = \frac{\mathrm{price}_i^{t+k}}{\mathrm{price}_i^{t+k-1}} - 1">
|
||||
</div>
|
||||
|
||||
* Temporally correlated task sets are defined as <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_k = \{T_1, T_2, ... , T_k\}">. In this paper, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">, <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5"> and <img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}"> are used.
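
The return definition and the task sets can be made concrete with a short sketch. It assumes `prices` is a plain list of daily prices for one stock indexed by day, so that `prices[t + k]` is the price k days after day t; the function names are illustrative and not part of Qlib.

```python
def k_step_return(prices, t, k):
    """Label of task T_k at day t: price[t+k] / price[t+k-1] - 1."""
    return prices[t + k] / prices[t + k - 1] - 1


def task_set_labels(prices, t, k_max):
    """Labels of the task set {T_1, ..., T_k_max} at day t."""
    return [k_step_return(prices, t, k) for k in range(1, k_max + 1)]


prices = [100.0, 110.0, 121.0, 108.9]
labels = task_set_labels(prices, t=0, k_max=3)  # roughly [0.1, 0.1, -0.1]
```

Each task therefore predicts a one-day return, shifted k days into the future, and a task set with a larger k simply appends further-out targets.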
|
||||
#### Baselines
|
||||
* GRU/MLP/LightGBM (LGB)/Graph Attention Networks (GAT)
|
||||
* Multi-task learning (MTL): In multi-task learning, multiple tasks are jointly trained and mutually boosted. Each task is treated equally, while in our setting, we focus on the main task.
|
||||
* Curriculum transfer learning (CL): Transfer learning also leverages auxiliary tasks to boost the main task. [Curriculum transfer learning](https://arxiv.org/pdf/1804.00810.pdf) is one kind of transfer learning which schedules auxiliary tasks according to certain rules. Our problem can also be regarded as a special kind of transfer learning, where the auxiliary tasks are temporally correlated with the main task. Our learning process is dynamically controlled by a scheduler rather than some pre-defined rules. In the CL baseline, we start from the task <img src="https://render.githubusercontent.com/render/math?math=T_1" >, then <img src="https://render.githubusercontent.com/render/math?math=T_2" >, and gradually move to the last one.
|
||||
#### Result
|
||||
| Methods | <img src="https://render.githubusercontent.com/render/math?math=T_1" > | <img src="https://render.githubusercontent.com/render/math?math=T_2"> | <img src="https://render.githubusercontent.com/render/math?math=T_3"> |
|
||||
| :----: | :----: | :----: | :----: |
|
||||
| GRU | 0.049 / 1.903 | 0.018 / 1.972 | 0.014 / 1.989 |
|
||||
| MLP | 0.023 / 1.961 | 0.022 / 1.962 | 0.015 / 1.978 |
|
||||
| LGB | 0.038 / 1.883 | 0.023 / 1.952 | 0.007 / 1.987 |
|
||||
| GAT | 0.052 / 1.898 | 0.024 / 1.954 | 0.015 / 1.973 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.061 / 1.862 | 0.023 / 1.942 | 0.012 / 1.956 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.051 / 1.880 | 0.028 / 1.941 | 0.016 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_3">) | 0.071 / 1.851 | 0.030 / 1.939 | 0.017 / 1.963 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.057 / 1.875 | 0.021 / 1.939 | 0.017 / 1.959 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.056 / 1.877 | 0.028 / 1.942 | 0.015 / 1.962 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_5">) | 0.075 / 1.849 | 0.032 / 1.939 | 0.021 / 1.955 |
|
||||
| MTL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.052 / 1.882 | 0.020 / 1.947 | 0.019 / 1.952 |
|
||||
| CL(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.051 / 1.882 | 0.028 / 1.950 | 0.016 / 1.961 |
|
||||
| Ours(<img src="https://render.githubusercontent.com/render/math?math=\mathcal{T}_{10}">) | 0.067 / 1.867 | 0.030 / 1.960 | 0.022 / 1.942 |
|
||||
BIN
examples/benchmarks/TCTS/task_description.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 25 KiB |
BIN
examples/benchmarks/TCTS/workflow.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 29 KiB |
93
examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml
Normal file
@@ -0,0 +1,93 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi300
|
||||
benchmark: &benchmark SH000300
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors:
|
||||
- class: RobustZScoreNorm
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
clip_outlier: true
|
||||
- class: Fillna
|
||||
kwargs:
|
||||
fields_group: feature
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1",
|
||||
"Ref($close, -3) / Ref($close, -1) - 1",
|
||||
"Ref($close, -4) / Ref($close, -1) - 1",
|
||||
"Ref($close, -5) / Ref($close, -1) - 1",
|
||||
"Ref($close, -6) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy.strategy
|
||||
kwargs:
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
verbose: False
|
||||
limit_threshold: 0.095
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: TCTS
|
||||
module_path: qlib.contrib.model.pytorch_tcts
|
||||
kwargs:
|
||||
d_feat: 6
|
||||
hidden_size: 64
|
||||
num_layers: 2
|
||||
dropout: 0.0
|
||||
n_epochs: 200
|
||||
lr: 1e-3
|
||||
early_stop: 20
|
||||
batch_size: 800
|
||||
metric: loss
|
||||
loss: mse
|
||||
GPU: 0
|
||||
fore_optimizer: adam
|
||||
weight_optimizer: adam
|
||||
output_dim: 5
|
||||
fore_lr: 5e-7
|
||||
weight_lr: 5e-7
|
||||
steps: 3
|
||||
target_label: 0
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs: {}
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
This example shows how a TrainerRM works based on TaskManager with rolling tasks.
|
||||
After training, how to collect the rolling results will be shown in task_collecting.
|
||||
Based on the ability of TaskManager, the `worker` method offers a simple way to do multiprocessing.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
@@ -13,7 +14,7 @@ import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.task.gen import RollingGen, task_generator
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.workflow.task.collect import RecorderCollector
|
||||
from qlib.model.ens.group import RollingGroup
|
||||
from qlib.model.trainer import TrainerRM
|
||||
@@ -68,6 +69,11 @@ class RollingTaskExample:
|
||||
trainer = TrainerRM(self.experiment_name, self.task_pool)
|
||||
trainer.train(tasks)
|
||||
|
||||
def worker(self):
|
||||
# Train tasks in other processes or on other machines for multiprocessing. It is the same as TrainerRM.worker.
|
||||
print("========== worker ==========")
|
||||
run_task(task_train, self.task_pool, experiment_name=self.experiment_name)
|
||||
|
||||
def task_collecting(self):
|
||||
print("========== task_collecting ==========")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Nested Decision Execution
|
||||
|
||||
This worflow is an example for nested decision execution in backtesting. Qlib supports nested decision execution in backtesting. It means that users can use different strategies to make trade decision in different frequencies.
|
||||
This workflow is an example of nested decision execution in backtesting, which Qlib supports. It means that users can use different strategies to make trade decisions at different frequencies.
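
To illustrate what "different strategies at different frequencies" looks like as configuration, here is a minimal sketch that builds a three-level executor config as plain dictionaries. The class names, module paths, and keys are the ones used in the workflow file of this commit; the helpers `simulator`, `nested`, and `frequencies` are hypothetical, and the sketch omits the strategy kwargs and report options that the real config sets.

```python
EXECUTOR_MODULE = "qlib.backtest.executor"
STRATEGY_MODULE = "qlib.contrib.strategy.rule_strategy"


def simulator(time_per_step):
    """Innermost level: simulates order execution at the given frequency."""
    return {
        "class": "SimulatorExecutor",
        "module_path": EXECUTOR_MODULE,
        "kwargs": {"time_per_step": time_per_step},
    }


def nested(time_per_step, inner_executor, inner_strategy_class):
    """One nesting level: its inner strategy decides for the inner executor."""
    return {
        "class": "NestedExecutor",
        "module_path": EXECUTOR_MODULE,
        "kwargs": {
            "time_per_step": time_per_step,
            "inner_executor": inner_executor,
            "inner_strategy": {
                "class": inner_strategy_class,
                "module_path": STRATEGY_MODULE,
            },
        },
    }


def frequencies(executor):
    """List the decision frequencies from the outermost to the innermost level."""
    levels = []
    while executor is not None:
        kwargs = executor["kwargs"]
        levels.append(kwargs["time_per_step"])
        executor = kwargs.get("inner_executor")
    return levels


# day -> 30min -> 5min, as in the updated workflow of this commit.
executor_config = nested(
    "day",
    nested("30min", simulator("5min"), "TWAPStrategy"),
    "SBBStrategyEMA",
)
```

Calling `frequencies(executor_config)` walks the nesting and returns the three levels in order, which is a quick way to check that a hand-written config nests as intended.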
|
||||
|
||||
## Weekly Portfolio Generation and Daily Order Execution
|
||||
|
||||
|
||||
@@ -14,14 +14,13 @@ from qlib.tests.data import GetData
|
||||
from qlib.backtest import collect_data
|
||||
|
||||
|
||||
class NestedDecisonExecutionWorkflow:
|
||||
class NestedDecisionExecutionWorkflow:
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2021-01-20",
|
||||
"end_time": "2020-12-31",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
@@ -53,9 +52,9 @@ class NestedDecisonExecutionWorkflow:
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"train": ("2007-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2021-01-20"),
|
||||
"test": ("2020-01-01", "2020-12-31"),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -66,35 +65,55 @@ class NestedDecisonExecutionWorkflow:
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "week",
|
||||
"time_per_step": "day",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"class": "NestedExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"verbose": True,
|
||||
"time_per_step": "30min",
|
||||
"inner_executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "5min",
|
||||
"generate_report": True,
|
||||
"verbose": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "TWAPStrategy",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
},
|
||||
"generate_report": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"inner_strategy": {
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"freq": "day",
|
||||
"instruments": market,
|
||||
"freq": "1min",
|
||||
},
|
||||
},
|
||||
"generate_report": True,
|
||||
"track_data": True,
|
||||
"generate_report": True,
|
||||
"indicator_config": {
|
||||
"show_indicator": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"start_time": "2020-01-01",
|
||||
"end_time": "2020-12-31",
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"freq": "1min",
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
@@ -106,11 +125,40 @@ class NestedDecisonExecutionWorkflow:
|
||||
|
||||
def _init_qlib(self):
|
||||
"""initialize qlib"""
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
# provider_uri_day = "/data/stock_data/huaxia/qlib"
|
||||
# provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib"
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True)
|
||||
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
GetData().qlib_data(
|
||||
target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True
|
||||
)
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config, redis_port=-1)
|
||||
|
||||
def _train_model(self, model, dataset):
|
||||
with R.start(experiment_name="train"):
|
||||
@@ -145,12 +193,25 @@ class NestedDecisonExecutionWorkflow:
|
||||
},
|
||||
}
|
||||
self.port_analysis_config["strategy"] = strategy_config
|
||||
self.port_analysis_config["backtest"]["benchmark"] = D.list_instruments(
|
||||
instruments=D.instruments(market=self.market), as_list=True
|
||||
)
|
||||
with R.start(experiment_name="backtest"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, self.port_analysis_config, "day")
|
||||
par = PortAnaRecord(
|
||||
recorder,
|
||||
self.port_analysis_config,
|
||||
risk_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_freq=["day", "30min", "5min"],
|
||||
indicator_analysis_method="value_weighted",
|
||||
)
|
||||
par.generate()
|
||||
|
||||
# report_normal_df = recorder.load_object("portfolio_analysis/report_normal_1day.pkl")
|
||||
# from qlib.contrib.report import analysis_position
|
||||
# analysis_position.report_graph(report_normal_df)
|
||||
|
||||
def collect_data(self):
|
||||
self._init_qlib()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
@@ -158,6 +219,7 @@ class NestedDecisonExecutionWorkflow:
|
||||
self._train_model(model, dataset)
|
||||
executor_config = self.port_analysis_config["executor"]
|
||||
backtest_config = self.port_analysis_config["backtest"]
|
||||
backtest_config["benchmark"] = D.list_instruments(instruments=D.instruments(market=self.market), as_list=True)
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
@@ -172,98 +234,6 @@ class NestedDecisonExecutionWorkflow:
|
||||
for trade_decision in data_generator:
|
||||
print(trade_decision)
|
||||
|
||||
def _init_qlib_with_backend(self):
|
||||
provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri")
|
||||
if not exists_qlib_data(provider_uri_1min):
|
||||
print(f"Qlib data is not found in {provider_uri_1min}")
|
||||
GetData().qlib_data(target_dir=provider_uri_1min, interval="1min", region=REG_CN)
|
||||
|
||||
# TODO: update latest data
|
||||
provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri_day):
|
||||
print(f"Qlib data is not found in {provider_uri_day}")
|
||||
GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN)
|
||||
|
||||
provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day}
|
||||
client_config = {
|
||||
"calendar_provider": {
|
||||
"class": "LocalCalendarProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileCalendarStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
"feature_provider": {
|
||||
"class": "LocalFeatureProvider",
|
||||
"module_path": "qlib.data.data",
|
||||
"kwargs": {
|
||||
"backend": {
|
||||
"class": "FileFeatureStorage",
|
||||
"module_path": "qlib.data.storage.file_storage",
|
||||
"kwargs": {"provider_uri_map": provider_uri_map},
|
||||
}
|
||||
},
|
||||
},
|
||||
}
|
||||
qlib.init(provider_uri=provider_uri_day, **client_config)
|
||||
|
||||
def _get_highfreq_config(self, model, dataset):
|
||||
|
||||
executor_config = self.port_analysis_config["executor"]
|
||||
# update executor with hierarchical decison freq ["day", "1min"]
|
||||
executor_config["kwargs"]["time_per_step"] = "day"
|
||||
executor_config["kwargs"]["inner_executor"]["kwargs"]["time_per_step"] = "15min"
|
||||
backtest_config = self.port_analysis_config["backtest"]
|
||||
|
||||
# yahoo highfreq data time
|
||||
backtest_config["start_time"] = "2020-09-20"
|
||||
backtest_config["end_time"] = "2021-01-20"
|
||||
|
||||
# update benchmark, yahoo data don't have SH000300
|
||||
instruments = D.instruments(market="csi300")
|
||||
instrument_list = D.list_instruments(instruments=instruments, as_list=True)
|
||||
backtest_config["benchmark"] = instrument_list
|
||||
|
||||
# update exchange config
|
||||
backtest_config["exchange_kwargs"]["freq"] = "1min"
|
||||
|
||||
# set strategy
|
||||
strategy_config = {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
}
|
||||
|
||||
return executor_config, strategy_config, backtest_config
|
||||
|
||||
def backtest_highfreq(self):
|
||||
self._init_qlib_with_backend()
|
||||
model = init_instance_by_config(self.task["model"])
|
||||
dataset = init_instance_by_config(self.task["dataset"])
|
||||
self._train_model(model, dataset)
|
||||
executor_config, strategy_config, backtest_config = self._get_highfreq_config(model, dataset)
|
||||
|
||||
highfreq_port_analysis_config = {
|
||||
"executor": executor_config,
|
||||
"strategy": strategy_config,
|
||||
"backtest": backtest_config,
|
||||
}
|
||||
|
||||
with R.start(experiment_name="backtest_highfreq"):
|
||||
|
||||
recorder = R.get_recorder()
|
||||
par = PortAnaRecord(recorder, highfreq_port_analysis_config, "day")
|
||||
par.generate()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(NestedDecisonExecutionWorkflow)
|
||||
fire.Fire(NestedDecisionExecutionWorkflow)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
This example shows how to simulate the OnlineManager based on rolling tasks.
|
||||
"""
|
||||
|
||||
from pprint import pprint
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM
|
||||
@@ -13,7 +14,7 @@ from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG_ONLINE, CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE
|
||||
|
||||
|
||||
class OnlineSimulationExample:
|
||||
@@ -22,8 +23,8 @@ class OnlineSimulationExample:
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
exp_name="rolling_exp",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_pool="rolling_task",
|
||||
rolling_step=80,
|
||||
start_time="2018-09-10",
|
||||
@@ -46,7 +47,7 @@ class OnlineSimulationExample:
|
||||
tasks (dict or list[dict]): a set of the task config waiting for rolling and training
|
||||
"""
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE, CSI100_RECORD_LGB_TASK_CONFIG_ONLINE]
|
||||
self.exp_name = exp_name
|
||||
self.task_pool = task_pool
|
||||
self.start_time = start_time
|
||||
@@ -59,7 +60,7 @@ class OnlineSimulationExample:
|
||||
self.rolling_gen = RollingGen(
|
||||
step=rolling_step, rtype=RollingGen.ROLL_SD, ds_extra_mod_func=None
|
||||
) # The rolling tasks generator, ds_extra_mod_func is None because we just need to simulate to 2018-10-31 and needn't change the handler end time.
|
||||
self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.trainer = TrainerRM(self.exp_name, self.task_pool) # Also can be TrainerR, TrainerRM, DelayTrainerR
|
||||
self.rolling_online_manager = OnlineManager(
|
||||
RollingStrategy(exp_name, task_template=tasks, rolling_gen=self.rolling_gen),
|
||||
trainer=self.trainer,
|
||||
@@ -85,6 +86,15 @@ class OnlineSimulationExample:
|
||||
print("========== signals ==========")
|
||||
print(self.rolling_online_manager.get_signals())
|
||||
|
||||
def worker(self):
|
||||
# Train tasks in other processes or on other machines for multiprocessing.
|
||||
# FIXME: when using DelayTrainerRM, this can only be called after the simulation has finished; otherwise an exception will be raised.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
self.trainer.worker()
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
## to run all workflow automatically with your own parameters, use the command below
|
||||
|
||||
@@ -13,11 +13,13 @@ Finally, the OnlineManager will finish second routine and update all strategies.
|
||||
import os
|
||||
import fire
|
||||
import qlib
|
||||
from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerR, TrainerRM, end_task_train, task_train
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.online.strategy import RollingStrategy
|
||||
from qlib.workflow.task.gen import RollingGen
|
||||
from qlib.workflow.online.manager import OnlineManager
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG, CSI100_RECORD_LGB_TASK_CONFIG
|
||||
from qlib.tests.config import CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING, CSI100_RECORD_LGB_TASK_CONFIG_ROLLING
|
||||
from qlib.workflow.task.manage import TaskManager
|
||||
|
||||
|
||||
class RollingOnlineExample:
|
||||
@@ -25,16 +27,17 @@ class RollingOnlineExample:
|
||||
self,
|
||||
provider_uri="~/.qlib/qlib_data/cn_data",
|
||||
region="cn",
|
||||
task_url="mongodb://10.0.0.4:27017/",
|
||||
task_db_name="rolling_db",
|
||||
trainer=DelayTrainerRM(), # you can choose from TrainerR, TrainerRM, DelayTrainerR, DelayTrainerRM
|
||||
task_url="mongodb://10.0.0.4:27017/", # not necessary when using TrainerR or DelayTrainerR
|
||||
task_db_name="rolling_db", # not necessary when using TrainerR or DelayTrainerR
|
||||
rolling_step=550,
|
||||
tasks=None,
|
||||
add_tasks=None,
|
||||
):
|
||||
if add_tasks is None:
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG]
|
||||
add_tasks = [CSI100_RECORD_LGB_TASK_CONFIG_ROLLING]
|
||||
if tasks is None:
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG]
|
||||
tasks = [CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING]
|
||||
mongo_conf = {
|
||||
"task_url": task_url, # your MongoDB url
|
||||
"task_db_name": task_db_name, # database name
|
||||
@@ -53,17 +56,28 @@ class RollingOnlineExample:
|
||||
RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD),
|
||||
)
|
||||
)
|
||||
|
||||
self.rolling_online_manager = OnlineManager(strategies)
|
||||
self.trainer = trainer
|
||||
self.rolling_online_manager = OnlineManager(strategies, trainer=self.trainer)
|
||||
|
||||
_ROLLING_MANAGER_PATH = (
|
||||
".RollingOnlineExample" # the OnlineManager will be dumped to this file so that it can be loaded when calling routine.
|
||||
)
|
||||
|
||||
def worker(self):
|
||||
# Train tasks in other processes or on other machines for multiprocessing.
|
||||
print("========== worker ==========")
|
||||
if isinstance(self.trainer, TrainerRM):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
self.trainer.worker(experiment_name=name_id)
|
||||
else:
|
||||
print(f"{type(self.trainer)} is not supported for worker.")
|
||||
|
||||
# Reset all things to the first status, be careful to save important data
|
||||
def reset(self):
|
||||
for task in self.tasks + self.add_tasks:
|
||||
name_id = task["model"]["class"]
|
||||
TaskManager(task_pool=name_id).remove()
|
||||
exp = R.get_exp(experiment_name=name_id)
|
||||
for rid in exp.list_recorders():
|
||||
exp.delete_recorder(rid)
|
||||
|
||||
@@ -362,8 +362,9 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b",
|
||||
"display_name": "Python 3.8 ('qlib_backtest': conda)"
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@@ -375,7 +376,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8"
|
||||
"version": "3.8.3"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -389,11 +390,6 @@
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
},
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
import copy
|
||||
from typing import Union
|
||||
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
from .backtest import collect_data as data_generator
|
||||
from .order import Order
|
||||
from .utils import TradeCalendarManager
|
||||
|
||||
from .utils import CommonInfrastructure
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure, TradeCalendarManager
|
||||
from .order import Order
|
||||
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
@@ -92,42 +93,114 @@ def get_exchange(
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}
|
||||
):
|
||||
trade_account = Account(
|
||||
init_cash=account,
|
||||
benchmark_config={
|
||||
def create_account_instance(
|
||||
start_time, end_time, benchmark: str, account: float, pos_type: str = "Position"
|
||||
) -> Account:
|
||||
"""
|
||||
# TODO: it is very strange to pass benchmark_config into the account (maybe it is for the report)
|
||||
# There should be a post-step to process the report.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
start time of the benchmark
|
||||
end_time :
|
||||
end time of the benchmark
|
||||
benchmark : str
|
||||
the benchmark for reporting
|
||||
account : Union[float, str]
|
||||
information describing how to create the account
|
||||
For `float`
|
||||
Using Account with a normal position
|
||||
For `str`:
|
||||
Using account with a specific Position
|
||||
"""
|
||||
kwargs = {
|
||||
"init_cash": account,
|
||||
"benchmark_config": {
|
||||
"benchmark": benchmark,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
},
|
||||
"pos_type": pos_type,
|
||||
}
|
||||
return Account(**kwargs)
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy: BaseStrategy,
|
||||
executor: BaseExecutor,
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, str] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
if "start_time" not in exchange_kwargs:
|
||||
exchange_kwargs["start_time"] = start_time
|
||||
if "end_time" not in exchange_kwargs:
|
||||
exchange_kwargs["end_time"] = end_time
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
common_infra = CommonInfrastructure(trade_account=trade_account, trade_exchange=trade_exchange)
|
||||
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy, common_infra=common_infra)
|
||||
trade_executor = init_instance_by_config(executor, accept_types=BaseExecutor, common_infra=common_infra)
|
||||
|
||||
return trade_strategy, trade_executor
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
def backtest(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark,
|
||||
account,
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
report_dict = backtest_func(start_time, end_time, trade_strategy, trade_executor)
|
||||
report_dict, indicator_dict = backtest_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict
|
||||
return report_dict, indicator_dict
|
||||
|
||||
|
||||
def collect_data(start_time, end_time, strategy, executor, benchmark="SH000300", account=1e9, exchange_kwargs={}):
|
||||
def collect_data(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
|
||||
trade_strategy, trade_executor = get_strategy_executor(
|
||||
start_time, end_time, strategy, executor, benchmark, account, exchange_kwargs
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark,
|
||||
account,
|
||||
exchange_kwargs,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
report_dict = yield from data_generator(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
return report_dict
|
||||
yield from collect_data_loop(start_time, end_time, trade_strategy, trade_executor)
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
|
||||
import copy
|
||||
from qlib.utils import init_instance_by_config
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .position import BasePosition, InfPosition, Position
|
||||
from .report import Report, Indicator
|
||||
from .order import Order
|
||||
|
||||
from .exchange import Exchange
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
@@ -25,29 +26,70 @@ rtn & earning in the Account
|
||||
while earning is the difference of two position value, so it considers cost, it is the true return rate
|
||||
in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning
|
||||
|
||||
Now rtn has been removed in the hierarchical backtest implementation.
|
||||
"""
|
||||
|
||||
|
||||
class AccumulatedInfo:
|
||||
"""accumulated trading info, including accumulated return/cost/turnover"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.rtn = 0 # accumulated return, do not consider cost
|
||||
self.cost = 0 # accumulated cost
|
||||
self.to = 0 # accumulated turnover
|
||||
|
||||
def add_return_value(self, value):
|
||||
self.rtn += value
|
||||
|
||||
def add_cost(self, value):
|
||||
self.cost += value
|
||||
|
||||
def add_turnover(self, value):
|
||||
self.to += value
|
||||
|
||||
@property
|
||||
def get_return(self):
|
||||
return self.rtn
|
||||
|
||||
@property
|
||||
def get_cost(self):
|
||||
return self.cost
|
||||
|
||||
@property
|
||||
def get_turnover(self):
|
||||
return self.to
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, freq: str = "day", benchmark_config: dict = {}):
|
||||
def __init__(
|
||||
self, init_cash: float = 1e9, freq: str = "day", benchmark_config: dict = {}, pos_type: str = "Position"
|
||||
):
|
||||
self.pos_type = pos_type
|
||||
self.init_vars(init_cash, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, freq: str, benchmark_config: dict):
|
||||
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current = Position(cash=init_cash)
|
||||
self.current: BasePosition = init_instance_by_config(
|
||||
{
|
||||
"class": self.pos_type,
|
||||
"kwargs": {"cash": init_cash},
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True)
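The position is now built from a config dict (`class`, `module_path`, `kwargs`) instead of a hard-coded `Position(cash=init_cash)`. As a rough sketch of that pattern only: `init_from_config` below is a hypothetical stand-in, not qlib's `init_instance_by_config`, and it instantiates a standard-library class so that it runs without qlib installed.

```python
import importlib


def init_from_config(config: dict):
    """Hypothetical config-driven factory (an assumption, not qlib's implementation)."""
    # Resolve the module named in the config, then the class inside it.
    module = importlib.import_module(config["module_path"])
    cls = getattr(module, config["class"])
    # Forward the keyword arguments to the constructor.
    return cls(**config.get("kwargs", {}))


# A stdlib class stands in for a Position subclass here.
position_like = init_from_config(
    {"class": "Counter", "module_path": "collections", "kwargs": {"cash": 3}}
)
print(type(position_like).__name__, position_like["cash"])
```

Swapping the `class` string is then enough to select a different implementation, which is presumably why `pos_type` is threaded through as a plain string.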
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
# portfolio related metrics
|
||||
self.report = Report(freq, benchmark_config)
|
||||
self.positions = {}
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.val = 0
|
||||
self.earning = 0
|
||||
|
||||
# trading related metrics (e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, init_report=False):
|
||||
"""reset freq and report of account
|
||||
@@ -73,27 +115,33 @@ class Account:
|
||||
return self.positions
|
||||
|
||||
def get_cash(self):
|
||||
return self.current.position["cash"]
|
||||
return self.current.get_cash()
|
||||
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
# update turnover
|
||||
self.to += trade_val
|
||||
self.accum_info.add_turnover(trade_val)
|
||||
# update cost
|
||||
self.ct += cost
|
||||
# update return
|
||||
# update self.rtn from order
|
||||
self.accum_info.add_cost(cost)
|
||||
|
||||
# update return from order
|
||||
trade_amount = trade_val / trade_price
|
||||
if order.direction == Order.SELL: # 0 for sell
|
||||
# when sell stock, get profit from price change
|
||||
profit = trade_val - self.current.get_stock_price(order.stock_id) * trade_amount
|
||||
self.rtn += profit # note here do not consider cost
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
elif order.direction == Order.BUY: # 1 for buy
|
||||
# when buy stock, we get return for the rtn computing method
|
||||
# profit in buy order is to make self.rtn is consistent with self.earning at the end of date
|
||||
# profit in buy order is to make rtn consistent with earning at the end of bar
|
||||
profit = self.current.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.rtn += profit
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
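The two profit formulas above, together with the docstring identity `rtn - cost = earning`, can be checked with a small worked example. This is a sketch under stated assumptions: `order_profit` and `held_price` are illustrative names (the latter plays the role of `self.current.get_stock_price(order.stock_id)`), and directions are plain strings rather than `Order.SELL` / `Order.BUY`.

```python
def order_profit(direction: str, trade_val: float, trade_price: float, held_price: float) -> float:
    """Cost-free profit booked for one order, mirroring the formulas in the diff."""
    trade_amount = trade_val / trade_price
    if direction == "sell":
        # Selling above the price recorded in the position yields a gain.
        return trade_val - held_price * trade_amount
    if direction == "buy":
        # Buying below the price recorded in the position yields a gain.
        return held_price * trade_amount - trade_val
    raise ValueError(f"unknown direction: {direction}")


# Sell 100 shares at 11.0 while the position records 10.0 per share.
sell_profit = order_profit("sell", trade_val=1100.0, trade_price=11.0, held_price=10.0)
# Buy 100 shares at 10.0 while the position records 10.5 per share.
buy_profit = order_profit("buy", trade_val=1000.0, trade_price=10.0, held_price=10.5)

rtn = sell_profit + buy_profit  # accumulated return, ignoring cost
cost = 5.0                      # hypothetical accumulated trading cost
earning = rtn - cost            # identity stated in the module docstring
print(sell_profit, buy_profit, earning)
```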
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
if self.current.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
@@ -110,47 +158,44 @@ class Account:
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_count(self):
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
def update_bar_report(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
trade_start_time: pd.TimeStamp
|
||||
trade_end_time: pd.TimeStamp
|
||||
quote: pd.DataFrame (code, date), columns
|
||||
when the end of trade date
|
||||
- update rtn
|
||||
- update price for each asset
|
||||
- update value for this account
|
||||
- update earning (2nd view of return )
|
||||
- update holding day, count of stock
|
||||
- update position history
|
||||
- update report
|
||||
:return: None
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
"""at the end of the trading bar, update holding bar, count of stock"""
|
||||
# update holding day count
|
||||
if not self.current.skip_update():
|
||||
self.current.add_count_all(bar=self.freq)
|
||||
|
||||
# update value
|
||||
self.val = self.current.calculate_value()
|
||||
# update earning
|
||||
def update_current(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
if not self.current.skip_update():
|
||||
stock_list = self.current.get_stock_list()
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
|
||||
def update_report(self, trade_start_time, trade_end_time):
|
||||
"""update position history, report"""
|
||||
# calculate earning
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, now_account_value, now_stock_value
|
||||
# get last_account_value, last_total_cost, last_total_turnover
|
||||
if self.report.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
last_total_cost = 0
|
||||
last_total_turnover = 0
|
||||
else:
|
||||
last_account_value = self.report.get_latest_account_value()
|
||||
last_total_cost = self.report.get_latest_total_cost()
|
||||
last_total_turnover = self.report.get_latest_total_turnover()
|
||||
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
|
||||
now_account_value = self.current.calculate_value()
|
||||
now_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = now_account_value - last_account_value
|
||||
now_earning = now_account_value - last_account_value
|
||||
now_cost = self.accum_info.get_cost - last_total_cost
|
||||
now_turnover = self.accum_info.get_turnover - last_total_turnover
|
||||
# update report for today
|
||||
# judge whether trading has begun.
|
||||
# and don't add the initial account state to the report, because we don't have excess return on those days.
|
||||
@@ -159,11 +204,13 @@ class Account:
|
||||
trade_end_time=trade_end_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
return_rate=(now_earning + now_cost) / last_account_value,
|
||||
# here use earning to calculate return, position's view, earning consider cost, true return
|
||||
# in order to make same definition with original backtest in evaluate.py
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
total_turnover=self.accum_info.get_turnover,
|
||||
turnover_rate=now_turnover / last_account_value,
|
||||
total_cost=self.accum_info.get_cost,
|
||||
cost_rate=now_cost / last_account_value,
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
# set now_account_value to position
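Because cost and turnover are now accumulated over the whole run, the per-bar figures are recovered as differences against the latest totals stored in the report. A minimal sketch of that arithmetic follows; `Totals` and `bar_metrics` are hypothetical names used only for illustration, while the formulas follow the lines above.

```python
import math
from dataclasses import dataclass


@dataclass
class Totals:
    """Running totals, standing in for AccumulatedInfo (illustrative only)."""
    cost: float = 0.0
    turnover: float = 0.0


def bar_metrics(now_value, last_value, totals, last_total_cost, last_total_turnover):
    # Per-bar values are deltas between the running totals and the last recorded totals.
    now_earning = now_value - last_value
    now_cost = totals.cost - last_total_cost
    now_turnover = totals.turnover - last_total_turnover
    return {
        # Adding the cost back gives the cost-free return rate.
        "return_rate": (now_earning + now_cost) / last_value,
        "turnover_rate": now_turnover / last_value,
        "cost_rate": now_cost / last_value,
    }


metrics = bar_metrics(
    now_value=1010.0,
    last_value=1000.0,
    totals=Totals(cost=7.0, turnover=500.0),
    last_total_cost=5.0,
    last_total_turnover=300.0,
)
print(metrics)
```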
|
||||
@@ -173,8 +220,63 @@ class Account:
|
||||
# note use deepcopy
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
|
||||
# finish today's updation
|
||||
# reset the bar variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
def update_bar_end(
|
||||
self,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
trade_exchange: Exchange,
|
||||
atomic: bool,
|
||||
generate_report: bool = False,
|
||||
trade_info: list = None,
|
||||
inner_order_indicators: Indicator = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
"""update account at each trading bar step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_start_time : pd.Timestamp
|
||||
closed start time of step
|
||||
trade_end_time : pd.Timestamp
|
||||
closed end time of step
|
||||
trade_exchange : Exchange
|
||||
trading exchange, used to update current
|
||||
atomic : bool
|
||||
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
|
||||
- if atomic is True, calculate the indicators with trade_info
|
||||
- else, aggregate indicators with inner indicators
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
trade_info : List[(Order, float, float, float)], optional
|
||||
trading information, by default None
|
||||
- necessary if atomic is True
|
||||
- list of tuple(order, trade_val, trade_cost, trade_price)
|
||||
inner_order_indicators : Indicator, optional
|
||||
indicators of inner executor, by default None
|
||||
- necessary if atomic is False
|
||||
- used to aggregate outer indicators
|
||||
indicator_config : dict, optional
|
||||
config of calculating indicators, by default {}
|
||||
"""
|
||||
if atomic is True and trade_info is None:
|
||||
raise ValueError("trade_info is necessary in atomic executor")
|
||||
elif atomic is False and inner_order_indicators is None:
|
||||
raise ValueError("inner_order_indicators is necessary in non-atomic executor")
|
||||
|
||||
if generate_report:
|
||||
# report is portfolio related analysis
|
||||
# TODO: `update_bar_count` and `update_current` should be placed in Position and be merged.
|
||||
self.update_bar_count()
|
||||
self.update_current(trade_start_time, trade_end_time, trade_exchange)
|
||||
self.update_report(trade_start_time, trade_end_time)
|
||||
|
||||
# indicator is trading (e.g. high-frequency order execution) related analysis
|
||||
self.indicator.clear()
|
||||
|
||||
if atomic:
|
||||
self.indicator.update_order_indicators(trade_start_time, trade_end_time, trade_info, trade_exchange)
|
||||
else:
|
||||
self.indicator.agg_order_indicators(inner_order_indicators, indicator_config)
|
||||
|
||||
self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config)
|
||||
self.indicator.record(trade_start_time)
|
||||
|
||||
@@ -1,30 +1,76 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.backtest.order import BaseTradeDecision
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.time import Freq
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
|
||||
def backtest(start_time, end_time, trade_strategy, trade_executor):
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
|
||||
Returns
|
||||
-------
|
||||
report: Report
|
||||
it records the trading report information
|
||||
"""
|
||||
return_value = {}
|
||||
for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value):
|
||||
pass
|
||||
return return_value.get("report"), return_value.get("indicator")
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
|
||||
):
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outermost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outermost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
trade_strategy : BaseStrategy
|
||||
the outermost portfolio strategy
|
||||
trade_executor : BaseExecutor
|
||||
the outermost executor
|
||||
return_value : dict
|
||||
used for backtest_loop
|
||||
|
||||
Yields
|
||||
-------
|
||||
object
|
||||
trade decision
|
||||
"""
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = trade_executor.execute(_trade_decision)
|
||||
with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar:
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
bar.update(1)
|
||||
|
||||
return trade_executor.get_report()
|
||||
if return_value is not None:
|
||||
all_executors = trade_executor.get_all_executors()
|
||||
|
||||
|
||||
def collect_data(start_time, end_time, trade_strategy, trade_executor):
|
||||
|
||||
trade_executor.reset(start_time=start_time, end_time=end_time)
|
||||
level_infra = trade_executor.get_level_infra()
|
||||
trade_strategy.reset(level_infra=level_infra)
|
||||
|
||||
_execute_result = None
|
||||
while not trade_executor.finished():
|
||||
_trade_decision = trade_strategy.generate_trade_decision(_execute_result)
|
||||
_execute_result = yield from trade_executor.collect_data(_trade_decision)
|
||||
|
||||
return trade_executor.get_report()
|
||||
all_reports = {
|
||||
"{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report()
|
||||
for _executor in all_executors
|
||||
if _executor.generate_report
|
||||
}
|
||||
all_indicators = {}
|
||||
for _executor in all_executors:
|
||||
key = "{}{}".format(*Freq.parse(_executor.time_per_step))
|
||||
all_indicators[key] = _executor.get_trade_indicator().generate_trade_indicators_dataframe()
|
||||
all_indicators[key + "_obj"] = _executor.get_trade_indicator()
|
||||
return_value.update({"report": all_reports, "indicator": all_indicators})
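`backtest_loop` reuses the generator by draining it and reading the results from a mutable `return_value` dict that the generator fills at the end. The pattern can be sketched in isolation as follows; `collect_loop` and `run_loop` are hypothetical names and the "report" payload is a placeholder, not qlib's report structure.

```python
def collect_loop(n_steps: int, return_value: dict = None):
    """Yield one decision per step; publish results through `return_value` if given."""
    for step in range(n_steps):
        yield f"decision-{step}"
    # Side channel: the caller's dict is updated once the loop has finished.
    if return_value is not None:
        return_value.update({"report": {"steps": n_steps}, "indicator": {}})


def run_loop(n_steps: int):
    """Drain the generator and return only the collected results."""
    return_value = {}
    for _decision in collect_loop(n_steps, return_value):
        pass  # decisions are ignored in a plain backtest
    return return_value.get("report"), return_value.get("indicator")


report, indicator = run_loop(3)
print(report, indicator)
```

Callers that want the decisions (for example, to build RL training data) iterate the generator directly and pass no dict; callers that only want the results use the draining wrapper.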
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
import random
|
||||
import logging
|
||||
from typing import Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
@@ -48,14 +49,17 @@ class Exchange:
|
||||
:param trade_unit: trade unit, 100 for China A market
|
||||
:param min_cost: min cost, default 5
|
||||
:param extra_quote: pandas, dataframe consists of
|
||||
columns: like ['$vwap', '$close', '$factor', 'limit'].
|
||||
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
|
||||
The limit indicates that the etf is tradable on a specific day.
|
||||
Necessary fields:
|
||||
$close is for calculating the total value at end of each day.
|
||||
Optional fields:
|
||||
$volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator
|
||||
$vwap is only necessary when we use the $vwap price as the deal price
|
||||
$factor is for rounding to the trading unit
|
||||
limit will be set to False by default(False indicates we can buy this
|
||||
limit_sell will be set to False by default(False indicates we can sell this
|
||||
target on this day).
|
||||
limit_buy will be set to False by default(False indicates we can buy this
|
||||
target on this day).
|
||||
index: MultiIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
@@ -171,8 +175,8 @@ class Exchange:
|
||||
self.quote = quote_dict
|
||||
|
||||
def _update_limit(self, buy_limit, sell_limit):
|
||||
self.quote["limit_buy"] = ~self.quote["$change"].lt(buy_limit)
|
||||
self.quote["limit_sell"] = ~self.quote["$change"].gt(-sell_limit)
|
||||
self.quote["limit_buy"] = self.quote["$change"].ge(buy_limit)
|
||||
self.quote["limit_sell"] = self.quote["$change"].le(-sell_limit)
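The rewrite from `~x.lt(limit)` to `x.ge(limit)` is not purely cosmetic if `$change` can be NaN (an assumption here, e.g. for bars without a valid quote): every ordered comparison with NaN is false, so negating `lt` marks NaN rows as limited while `ge` does not. The scalar analogue below illustrates this with plain floats; pandas element-wise comparisons follow the same rule.

```python
import math


def limit_buy_old(change: float, buy_limit: float) -> bool:
    # Scalar analogue of ~change.lt(buy_limit)
    return not (change < buy_limit)


def limit_buy_new(change: float, buy_limit: float) -> bool:
    # Scalar analogue of change.ge(buy_limit)
    return change >= buy_limit


# Compare both definitions for a limit-up bar, a normal bar, and a missing value.
for change in (0.10, 0.01, math.nan):
    print(change, limit_buy_old(change, 0.095), limit_buy_new(change, 0.095))
```

For ordinary numbers the two definitions agree; they differ only on NaN, where the old form reports the stock as limited and the new form does not.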
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
"""
|
||||
@@ -256,6 +260,16 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def create_order(self, code, amount, start_time, end_time, direction) -> Order:
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
factor=self.get_factor(code, start_time, end_time),
|
||||
)
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0]
|
||||
|
||||
@@ -275,8 +289,20 @@ class Exchange:
|
||||
deal_price = self.get_close(stock_id, start_time, end_time)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time):
|
||||
return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0]
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
if stock_id not in self.quote:
|
||||
return None
|
||||
res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last")
|
||||
if res is not None:
|
||||
res = res.iloc[0]
|
||||
return res
|
||||
|
||||
def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time):
|
||||
"""
|
||||
@@ -342,7 +368,10 @@ class Exchange:
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
"""Parameter:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
|
||||
Parameter:
|
||||
target_position : dict { stock_id : amount }
|
||||
current_position : dict { stock_id : amount }
|
||||
trade_unit : trade_unit
|
||||
|
||||
@@ -3,14 +3,16 @@ import warnings
|
||||
import pandas as pd
|
||||
from typing import Union
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.resam import parse_freq
|
||||
from qlib.backtest.report import Indicator
|
||||
|
||||
|
||||
from .order import Order
|
||||
from .order import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
"""Base executor for trading"""
|
||||
@@ -20,6 +22,7 @@ class BaseExecutor:
|
||||
time_per_step: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
@@ -31,12 +34,47 @@ class BaseExecutor:
|
||||
----------
|
||||
time_per_step : str
|
||||
trade time per trading step, used to generate the trade calendar
|
||||
show_indicator: bool, optional
|
||||
whether to show indicators, :
|
||||
- 'pa', the price advantage
|
||||
- 'pos', the positive rate
|
||||
- 'ffr', the fulfill rate
|
||||
indicator_config: dict, optional
|
||||
config for calculating trade indicator, including the following fields:
|
||||
- 'show_indicator': whether to show indicators, optional, False by default. The indicators include
|
||||
- 'pa', the price advantage
|
||||
- 'pos', the positive rate
|
||||
- 'ffr', the fulfill rate
|
||||
- 'pa_config': config for calculating price advantage(pa), optional
|
||||
- 'base_price': the base price against which the trading price's advantage is measured, optional, 'twap' by default
|
||||
- If 'base_price' is 'twap', the base price is the time weighted average price
|
||||
- If 'base_price' is 'vwap', the base price is the volume weighted average price
|
||||
- 'weight_method': weighting method used to combine different orders' pa into the total trading pa in each step, optional, 'mean' by default
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
|
||||
- 'ffr_config': config for calculating fulfill rate(ffr), optional
|
||||
- 'weight_method': weighting method used to combine different orders' ffr into the total trading ffr in each step, optional, 'mean' by default
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
|
||||
Example:
|
||||
{
|
||||
'show_indicator': True,
|
||||
'pa_config': {
|
||||
'base_price': 'twap',
|
||||
'weight_method': 'value_weighted',
|
||||
},
|
||||
'ffr_config':{
|
||||
'weight_method': 'value_weighted',
|
||||
}
|
||||
}
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate trade_decision, will be used when making data for multi-level training
|
||||
whether to generate trade_decision, will be used when training rl agent
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
common_infra : CommonInfrastructure, optional:
|
||||
@@ -48,6 +86,7 @@ class BaseExecutor:
|
||||
|
||||
"""
|
||||
self.time_per_step = time_per_step
|
||||
self.indicator_config = indicator_config
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
self.track_data = track_data
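The three `weight_method` options named in the `indicator_config` docstring ('mean', 'amount_weighted', 'value_weighted') can be sketched as weighted averages over per-order values. This is an illustration only: `aggregate` is a hypothetical helper, and taking an order's value to be amount times price is an assumption, not something the diff states.

```python
import math


def aggregate(values, amounts, prices, weight_method: str = "mean") -> float:
    """Combine per-order indicator values (e.g. pa or ffr) into one number."""
    if weight_method == "mean":
        weights = [1.0] * len(values)
    elif weight_method == "amount_weighted":
        weights = list(amounts)
    elif weight_method == "value_weighted":
        # Assumption: an order's value is its traded amount times its price.
        weights = [a * p for a, p in zip(amounts, prices)]
    else:
        raise ValueError(f"unsupported weight_method: {weight_method}")
    return sum(v * w for v, w in zip(values, weights)) / sum(weights)


pa = [0.01, 0.03]        # per-order price advantage (made-up numbers)
amounts = [100.0, 300.0]
prices = [10.0, 20.0]

for method in ("mean", "amount_weighted", "value_weighted"):
    print(method, aggregate(pa, amounts, prices, method))
```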
|
||||
@@ -98,28 +137,51 @@ class BaseExecutor:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : object
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
execute_result : List[object]
|
||||
the executed result for trade decison
|
||||
the executed result for trade decision
|
||||
"""
|
||||
raise NotImplementedError("execute is not implemented!")
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
|
||||
Returns
|
||||
----------
|
||||
execute_result : List[object]
|
||||
the executed result for trade decision
|
||||
|
||||
Yields
|
||||
-------
|
||||
object
|
||||
trade decision
|
||||
"""
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
return self.execute(trade_decision)
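`collect_data` is a generator that optionally yields the decision and then returns the execution result; a caller obtains that return value through `yield from`. The sketch below shows the mechanism with hypothetical stand-ins (`SketchExecutor`, `drive`, `run`), not qlib classes.

```python
class SketchExecutor:
    """Toy executor: yields the decision when tracking, returns the execute result."""

    def __init__(self, track_data: bool):
        self.track_data = track_data

    def execute(self, trade_decision):
        return [f"executed:{trade_decision}"]

    def collect_data(self, trade_decision):
        if self.track_data:
            yield trade_decision
        # A generator's return value becomes the value of `yield from` in the caller.
        return self.execute(trade_decision)


def drive(executor, decisions):
    results = []
    for decision in decisions:
        result = yield from executor.collect_data(decision)
        results.extend(result)
    return results


def run(gen):
    """Exhaust a generator, returning (yielded items, return value)."""
    yielded = []
    while True:
        try:
            yielded.append(next(gen))
        except StopIteration as stop:
            return yielded, stop.value


print(run(drive(SketchExecutor(track_data=True), ["a", "b"])))
print(run(drive(SketchExecutor(track_data=False), ["a", "b"])))
```

With `track_data=False` nothing is yielded, yet the function is still a generator, so the execution results reach the caller the same way.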
|
||||
|
||||
def get_trade_account(self):
|
||||
raise NotImplementedError("get_trade_account is not implemented!")
|
||||
|
||||
def get_report(self):
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
"""get the history report and positions instance"""
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
return _report, _positions
|
||||
else:
|
||||
raise ValueError("generate_report should be True if you want to generate report")
|
||||
|
||||
def get_trade_indicator(self) -> Indicator:
|
||||
"""get the trade indicator instance, which has pa/pos/ffr info."""
|
||||
return self.trade_account.indicator
|
||||
|
||||
def get_all_executors(self):
|
||||
"""Return all executors"""
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
|
||||
|
||||
@@ -129,8 +191,6 @@ class NestedExecutor(BaseExecutor):
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
|
||||
"""
|
||||
|
||||
from ..strategy.base import BaseStrategy
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
time_per_step: str,
|
||||
@@ -138,6 +198,7 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_strategy: Union[BaseStrategy, dict],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
@@ -161,13 +222,14 @@ class NestedExecutor(BaseExecutor):
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
)
|
||||
self.inner_strategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy
|
||||
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
|
||||
)
|
||||
|
||||
super(NestedExecutor, self).__init__(
|
||||
time_per_step=time_per_step,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
indicator_config=indicator_config,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
@@ -175,7 +237,7 @@ class NestedExecutor(BaseExecutor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if generate_report and trade_exchange is not None:
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
@@ -186,7 +248,7 @@ class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
super(NestedExecutor, self).reset_common_infra(common_infra)
|
||||
|
||||
if self.generate_report and common_infra.has("trade_exchange"):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
self.inner_executor.reset_common_infra(common_infra)
|
||||
@@ -199,57 +261,56 @@ class NestedExecutor(BaseExecutor):
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision)
|
||||
|
||||
def _update_trade_account(self):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_count()
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
|
||||
def execute(self, trade_decision):
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_result = []
|
||||
_inner_execute_result = None
|
||||
while not self.inner_executor.finished():
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
_inner_execute_result = self.inner_executor.execute(trade_decision=_inner_trade_decision)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
self.trade_calendar.step()
|
||||
return execute_result
|
||||
return_value = {}
|
||||
for _decision in self.collect_data(trade_decision, return_value):
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
def collect_data(self, trade_decision):
|
||||
def collect_data(self, trade_decision: BaseTradeDecision, return_value=None):
|
||||
if self.track_data:
|
||||
yield trade_decision
|
||||
self.trade_calendar.step()
|
||||
self._init_sub_trading(trade_decision)
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
_inner_execute_result = None
|
||||
while not self.inner_executor.finished():
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
if hasattr(self, "trade_account"):
|
||||
self._update_trade_account()
|
||||
# the outer strategy has a chance to update its decision in each iteration
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for the inner strategy to update the outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
|
||||
_inner_trade_decision = self.inner_strategy.generate_trade_decision(_inner_execute_result)
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the following line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(trade_decision=_inner_trade_decision)
|
||||
|
||||
execute_result.extend(_inner_execute_result)
|
||||
inner_order_indicators.append(self.inner_executor.get_trade_indicator().get_order_indicator())
|
||||
|
||||
if hasattr(self, "trade_account"):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=False,
|
||||
generate_report=self.generate_report,
|
||||
inner_order_indicators=inner_order_indicators,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
|
||||
self.trade_calendar.step()
|
||||
if return_value is not None:
|
||||
return_value.update({"execute_result": execute_result})
|
||||
return execute_result
|
||||
|
||||
def get_report(self):
|
||||
sub_env_report_dict = self.inner_executor.get_report()
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
_count, _freq = parse_freq(self.time_per_step)
|
||||
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
|
||||
return sub_env_report_dict
|
||||
|
||||
def get_all_executors(self):
|
||||
"""Return all executors, including self and inner_executor.get_all_executors()"""
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
@@ -261,6 +322,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
time_per_step: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
indicator_config: dict = {},
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
track_data: bool = False,
|
||||
@@ -279,6 +341,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
time_per_step=time_per_step,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
indicator_config=indicator_config,
|
||||
generate_report=generate_report,
|
||||
verbose=verbose,
|
||||
track_data=track_data,
|
||||
@@ -297,12 +360,12 @@ class SimulatorExecutor(BaseExecutor):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def execute(self, trade_decision):
|
||||
def execute(self, trade_decision: BaseTradeDecision):
|
||||
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
execute_result = []
|
||||
for order in trade_decision:
|
||||
for order in trade_decision.get_decision():
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(
|
||||
@@ -337,26 +400,18 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
else:
|
||||
if self.verbose:
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
print("[W {:%Y-%m-%d %H:%M:%S}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
|
||||
self.trade_account.update_bar_count()
|
||||
|
||||
if self.generate_report:
|
||||
self.trade_account.update_bar_report(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_exchange=self.trade_exchange,
|
||||
)
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
self.trade_exchange,
|
||||
atomic=True,
|
||||
generate_report=self.generate_report,
|
||||
trade_info=execute_result,
|
||||
indicator_config=self.indicator_config,
|
||||
)
|
||||
self.trade_calendar.step()
|
||||
return execute_result
|
||||
|
||||
def get_report(self):
|
||||
if self.generate_report:
|
||||
_report = self.trade_account.report.generate_report_dataframe()
|
||||
_positions = self.trade_account.get_positions()
|
||||
_count, _freq = parse_freq(self.time_per_step)
|
||||
return {f"{_count}{_freq}": (_report, _positions)}
|
||||
else:
|
||||
return {}
|
||||
|
||||
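The `execute` / `collect_data` pairing above relies on two generator features: `yield from` evaluates to the inner generator's return value, and a mutable `return_value` dict lets a plain loop recover the final result. A minimal sketch of that pattern follows, assuming toy classes (`ToyInnerExecutor`, `ToyNestedExecutor`) that are hypothetical and not part of qlib; unlike the diff, this toy always yields the decision instead of checking `track_data`.

```python
from typing import Generator, List, Optional


class ToyInnerExecutor:
    """Hypothetical stand-in for an inner executor; not part of qlib."""

    def collect_data(self, decision: List[str]) -> Generator[List[str], None, List[str]]:
        # Yield the decision so a caller can observe it, then return a result.
        yield decision
        return [f"filled:{order}" for order in decision]


class ToyNestedExecutor:
    """Hypothetical outer executor that delegates each step to the inner one."""

    def __init__(self, steps: int) -> None:
        self.inner = ToyInnerExecutor()
        self.steps = steps

    def collect_data(self, decision: List[str], return_value: Optional[dict] = None):
        yield decision
        results: List[str] = []
        for step in range(self.steps):
            inner_decision = [f"{order}@{step}" for order in decision]
            # `yield from` re-yields the inner items and evaluates to the
            # inner generator's return value.
            inner_result = yield from self.inner.collect_data(inner_decision)
            results.extend(inner_result)
        if return_value is not None:
            # Expose the result to callers that only iterate the generator.
            return_value.update({"execute_result": results})
        return results

    def execute(self, decision: List[str]) -> List[str]:
        return_value: dict = {}
        for _ in self.collect_data(decision, return_value):
            pass  # drain the generator; the yielded decisions are ignored here
        return return_value.get("execute_result")


executor = ToyNestedExecutor(steps=2)
result = executor.execute(["buy_A"])
tracked = list(executor.collect_data(["buy_A"]))
```

Here `result` holds the execution results collected across both inner steps, while `tracked` holds every decision yielded along the way, which is the data a tracking caller would consume.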
@@ -1,8 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
# TODO: rename it with decision.py
|
||||
from __future__ import annotations
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
import warnings
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar, Optional
|
||||
from typing import ClassVar, Optional, Union, List, Set, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -34,3 +44,198 @@ class Order:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
Trade decisions are made by the strategy and executed by the executor
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
|
||||
Case 1:
|
||||
1. Outer strategy makes a decision. The decision is not available at the start of current interval
|
||||
2. After a period of time, the decision is updated and becomes available
|
||||
3. The inner strategy tries to get the decision and starts to execute it according to `get_range_limit`
|
||||
Case 2:
|
||||
1. The outer strategy's decision is available at the start of the interval
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
The strategy that makes the decision
|
||||
"""
|
||||
self.strategy = strategy
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
"""
|
||||
get the **concrete decision** (e.g. execution orders)
|
||||
This will be called by the inner strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[object]:
|
||||
The decision result. Typically it is some orders
|
||||
Example:
|
||||
[]:
|
||||
Decision not available
|
||||
concrete_decision:
|
||||
available
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
"""
|
||||
Be called at the **start** of each step
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
"""
|
||||
return self.strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
"""
|
||||
return the expected step range for limiting the decision execution time
|
||||
Both left and right are **closed**
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
|
||||
Raises
|
||||
------
|
||||
NotImplementedError:
|
||||
If the decision can't provide a unified start and end
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_range_limit` method")
|
||||
|
||||
|
||||
class TradeDecisionWO(BaseTradeDecision):
|
||||
"""
|
||||
Trade Decision (W)ith (O)rder.
|
||||
Besides, the time_range is also included.
|
||||
"""
|
||||
|
||||
def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple = None):
|
||||
super().__init__(strategy)
|
||||
self.order_list = order_list
|
||||
self.idx_range = idx_range
|
||||
|
||||
def get_range_limit(self) -> Tuple[int, int]:
|
||||
if self.idx_range is None:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range")
|
||||
return self.idx_range
|
||||
|
||||
def get_decision(self) -> List[object]:
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"strategy: {self.strategy}; idx_range: {self.idx_range}; order_list[{len(self.order_list)}]"
|
||||
|
||||
|
||||
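`get_range_limit` signals a missing range by raising `NotImplementedError`, so a consumer needs a fallback. A minimal sketch of one way an inner strategy could handle that, assuming a toy `ToyDecision` class and a hypothetical `resolve_range` helper (neither is part of qlib):

```python
from typing import List, Optional, Tuple


class ToyDecision:
    """Hypothetical mirror of the TradeDecisionWO shape; not the qlib class."""

    def __init__(self, order_list: List[str], idx_range: Optional[Tuple[int, int]] = None) -> None:
        self.order_list = order_list
        self.idx_range = idx_range

    def get_decision(self) -> List[str]:
        return self.order_list

    def get_range_limit(self) -> Tuple[int, int]:
        if self.idx_range is None:
            raise NotImplementedError("The decision didn't provide an index range")
        return self.idx_range


def resolve_range(decision: ToyDecision, total_steps: int) -> Tuple[int, int]:
    """Fall back to the full step range when the decision provides no limit."""
    try:
        start, end = decision.get_range_limit()
    except NotImplementedError:
        return 0, total_steps - 1
    # Both ends are closed, so clip them to the valid step indices.
    return max(start, 0), min(end, total_steps - 1)


limited = resolve_range(ToyDecision(["order_a"], idx_range=(2, 10)), total_steps=8)
unlimited = resolve_range(ToyDecision(["order_a"]), total_steps=8)
```

With eight inner steps, the first call clips the requested range to the last valid index, and the second falls back to the whole interval.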
# TODO: the orders below need to be discussed ------------------------------------
|
||||
# - The classes below are designed for Case 1
|
||||
# - However, Case 1 can't take `order_pool` as an argument of the constructor
|
||||
class TradeDecisionWithOrderPool:
|
||||
"""trade decision made by the strategy"""
|
||||
|
||||
def __init__(self, strategy, order_pool):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
the original strategy that makes the decision
|
||||
order_pool : list, optional
|
||||
the candidate order pool for generating the trade decision
|
||||
"""
|
||||
super(TradeDecisionWithOrderPool, self).__init__(strategy)
|
||||
self.order_pool = order_pool
|
||||
self.order_list = []
|
||||
|
||||
def pop_order_pool(self, pop_len):
|
||||
if pop_len > len(self.order_pool):
|
||||
warnings.warn(
|
||||
f"pop len {pop_len} is longer than the order pool, truncating it to the pool length {len(self.order_pool)}"
|
||||
)
|
||||
pop_len = len(self.order_pool)
|
||||
res = self.order_pool[:pop_len]
|
||||
del self.order_pool[:pop_len]
|
||||
return res
|
||||
|
||||
def push_order_list(self, order_list):
|
||||
self.order_list.extend(order_list)
|
||||
|
||||
def get_decision(self):
|
||||
"""get the order list
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_enable : bool, optional
|
||||
whether to ignore disabled orders, by default False
|
||||
only_disable : bool, optional
|
||||
whether to ignore enabled orders, by default False
|
||||
Returns
|
||||
-------
|
||||
List[Order]
|
||||
the order list
|
||||
"""
|
||||
return self.order_list
|
||||
|
||||
def update(self, trade_calendar):
|
||||
"""make the original strategy update the enabled status of orders."""
|
||||
self.ori_strategy.update_trade_decision(self, trade_calendar)
|
||||
|
||||
|
||||
class BaseDecisionUpdater:
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
decision : BaseTradeDecision
|
||||
the trade decision to be updated
|
||||
trade_calendar : BaseTradeCalendar
|
||||
the trade calendar of inner execution
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
the updated decision
|
||||
"""
|
||||
raise NotImplementedError(f"This method is not implemented")
|
||||
|
||||
|
||||
class DecisionUpdaterWithOrderPool:
|
||||
def __init__(self, plan_config=None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
plan_config : Dict[Tuple(int, float)], optional
|
||||
the plan config, by default None
|
||||
"""
|
||||
if plan_config is None:
|
||||
self.plan_config = [(0, 1)]
|
||||
else:
|
||||
self.plan_config = plan_config
|
||||
|
||||
def update_decision(self, decision, trade_calendar) -> BaseTradeDecision:
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = trade_calendar.get_trade_step()
|
||||
for _index, _ratio in self.plan_config:
|
||||
if trade_step == _index:
|
||||
pop_len = len(decision.order_pool) * _ratio
|
||||
pop_order_list = decision.pop_order_pool(pop_len)
|
||||
decision.push_order_list(pop_order_list)
|
||||
|
||||
@@ -4,30 +4,200 @@
|
||||
|
||||
import copy
|
||||
import pathlib
|
||||
from typing import Dict, List
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from .order import Order
|
||||
|
||||
"""
|
||||
Position module
|
||||
"""
|
||||
|
||||
"""
|
||||
current state of position
|
||||
a typical example is :{
|
||||
<instrument_id>: {
|
||||
'count': <how many days the security has been held>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
},
|
||||
}
|
||||
class BasePosition:
|
||||
"""
|
||||
The Position class maintains the position like a dictionary
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
"""
|
||||
def __init__(self, cash=0.0, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
"""
|
||||
Whether to skip the update operation for this position
|
||||
For example, updating is meaningless for InfPosition
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
whether to skip the update operation
|
||||
"""
|
||||
return False
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
"""
|
||||
check if the stock is in the position
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
the id of the stock
|
||||
|
||||
Returns
|
||||
-------
|
||||
bool:
|
||||
whether the stock is in the position
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
order : Order
|
||||
the order to update the position
|
||||
trade_val : float
|
||||
the trade value(money) of dealing results
|
||||
cost : float
|
||||
the trade cost of the dealing results
|
||||
trade_price : float
|
||||
the trade price of the dealing results
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_order` method")
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
"""
|
||||
Update the latest price of the stock
|
||||
This is useful when clearing the balance at the end of each bar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id :
|
||||
the id of the stock
|
||||
price : float
|
||||
the price to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_stock_price` method")
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
"""
|
||||
calculate the value of the all assets except cash in the position
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the value(money) of all the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_list` method")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
"""
|
||||
get the latest price of the stock
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code :
|
||||
the code of the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_price` method")
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
"""
|
||||
get the amount of the stock
|
||||
|
||||
Parameters
|
||||
----------
|
||||
code :
|
||||
the code of the stock
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the amount of the stock
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount` method")
|
||||
|
||||
def get_cash(self) -> float:
|
||||
"""
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the cash in position
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
"""
|
||||
generate stock amount dict {stock_id : amount of stock}
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict:
|
||||
{stock_id : amount of stock}
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_amount_dict` method")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
"""
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
|
||||
Parameters
|
||||
----------
|
||||
only_stock : bool
|
||||
If only_stock=True, the weight of each stock in total stock will be returned
|
||||
If only_stock=False, the weight of each stock in total assets(stock + cash) will be returned
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict:
|
||||
{stock_id : value weight of stock in the position}
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
def update_weight_all(self):
|
||||
"""
|
||||
Updating the position weight;
|
||||
|
||||
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
|
||||
# and before updating weight.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_weight_all` method")
|
||||
|
||||
|
||||
class Position:
|
||||
"""Position"""
|
||||
class Position(BasePosition):
|
||||
"""Position
|
||||
|
||||
current state of position
|
||||
a typical example is :{
|
||||
<instrument_id>: {
|
||||
'count': <how many days the security has been held>,
|
||||
'amount': <the amount of the security>,
|
||||
'price': <the close price of security in the last trading day>,
|
||||
'weight': <the security weight of total position value>,
|
||||
},
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cash=0, position_dict={}, now_account_value=0):
|
||||
# NOTE: The position dict must be copied!!!
|
||||
@@ -37,23 +207,35 @@ class Position:
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
def init_stock(self, stock_id, amount, price=None):
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
"""
|
||||
initialize the stock in the current position
|
||||
|
||||
Parameters
|
||||
----------
|
||||
stock_id :
|
||||
the id of the stock
|
||||
amount : float
|
||||
the amount of the stock
|
||||
price :
|
||||
the price when buying the init stock
|
||||
"""
|
||||
self.position[stock_id] = {}
|
||||
self.position[stock_id]["amount"] = amount
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self.init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
else:
|
||||
# exist, add amount
|
||||
self.position[stock_id]["amount"] += trade_amount
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
@@ -66,11 +248,11 @@ class Position:
|
||||
"only have {} {}, require {}".format(self.position[stock_id]["amount"], stock_id, trade_amount)
|
||||
)
|
||||
elif abs(self.position[stock_id]["amount"]) <= 1e-5:
|
||||
self.del_stock(stock_id)
|
||||
self._del_stock(stock_id)
|
||||
|
||||
self.position["cash"] += trade_val - cost
|
||||
|
||||
def del_stock(self, stock_id):
|
||||
def _del_stock(self, stock_id):
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
@@ -80,10 +262,10 @@ class Position:
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
self.buy_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
self._buy_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
elif order.direction == Order.SELL:
|
||||
# SELL
|
||||
self.sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
self._sell_stock(order.stock_id, trade_val, cost, trade_price)
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
@@ -122,6 +304,7 @@ class Position:
|
||||
return self.position[code]["amount"]
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
"""the number of bars the stock has been held; it may be used in some special strategies"""
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
@@ -215,3 +398,59 @@ class Position:
|
||||
self.position = positions
|
||||
self.position["cash"] = cash
|
||||
self.position["now_account_value"] = now_account_value
|
||||
|
||||
|
||||
class InfPosition(BasePosition):
|
||||
"""
|
||||
Position with infinite cash and amount.
|
||||
|
||||
This is useful for generating random orders.
|
||||
"""
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
""" Updating state is meaningless for InfPosition """
|
||||
return True
|
||||
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
# InfPosition always holds every stock
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
pass
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
pass
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
infinity stock value
|
||||
"""
|
||||
return np.inf
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
"""the price of the inf position is meaningless"""
|
||||
return np.nan
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_cash(self) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_amount_dict")
|
||||
|
||||
def get_stock_weight_dict(self, only_stock: bool) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
|
||||
|
||||
def update_weight_all(self):
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
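The `skip_update` hook lets calling code treat finite and infinite positions uniformly. A minimal sketch of that idea, assuming toy classes (`ToyPosition`, `ToyInfPosition`) and a hypothetical `settle_buy` helper that are illustrative only and not qlib's API:

```python
import math
from typing import Dict


class ToyPosition:
    """Hypothetical finite position that tracks cash and holdings."""

    def __init__(self, cash: float) -> None:
        self.cash = cash
        self.amounts: Dict[str, float] = {}

    def skip_update(self) -> bool:
        return False

    def get_cash(self) -> float:
        return self.cash

    def update_buy(self, stock_id: str, amount: float, trade_val: float, cost: float) -> None:
        # Add the bought amount and pay the traded value plus the cost.
        self.amounts[stock_id] = self.amounts.get(stock_id, 0.0) + amount
        self.cash -= trade_val + cost


class ToyInfPosition(ToyPosition):
    """Hypothetical position with infinite cash; updates carry no information."""

    def __init__(self) -> None:
        super().__init__(cash=math.inf)

    def skip_update(self) -> bool:
        return True

    def update_buy(self, stock_id: str, amount: float, trade_val: float, cost: float) -> None:
        pass


def settle_buy(position: ToyPosition, stock_id: str, amount: float, price: float, cost: float) -> None:
    """Apply a filled buy order unless the position opts out of updates."""
    if position.skip_update():
        return
    position.update_buy(stock_id, amount, amount * price, cost)


finite = ToyPosition(cash=1000.0)
settle_buy(finite, "SH600000", amount=10.0, price=5.0, cost=1.0)

infinite = ToyInfPosition()
settle_buy(infinite, "SH600000", amount=10.0, price=5.0, cost=1.0)
```

The finite position pays for the trade, while the infinite one stays untouched, so order generation can run against it without any bookkeeping.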
@@ -1,6 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
"""
|
||||
This module is not well maintained.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
@@ -7,18 +7,27 @@ from logging import warning
|
||||
import pandas as pd
|
||||
import pathlib
|
||||
import warnings
|
||||
from pandas.core import groupby
|
||||
|
||||
from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.resam import parse_freq, resam_ts_data
|
||||
from ..utils.time import Freq
|
||||
from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
|
||||
|
||||
class Report:
|
||||
# daily report of the account
|
||||
# contains the following: returns, costs, turnovers, accounts, cash, bench, value
|
||||
# update report
|
||||
"""
|
||||
Motivation:
|
||||
Report is for supporting portfolio related metrics.
|
||||
|
||||
Implementation:
|
||||
daily report of the account
|
||||
contains the following: returns, costs, turnovers, accounts, cash, bench, value
|
||||
update report
|
||||
"""
|
||||
|
||||
def __init__(self, freq: str = "day", benchmark_config: dict = {}):
|
||||
"""
|
||||
Parameters
|
||||
@@ -51,11 +60,13 @@ class Report:
|
||||
self.init_bench(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def init_vars(self):
|
||||
self.accounts = OrderedDict() # account position value for each trade date
|
||||
self.returns = OrderedDict() # daily return rate for each trade date
|
||||
self.turnovers = OrderedDict() # turnover for each trade date
|
||||
self.costs = OrderedDict() # trade cost for each trade date
|
||||
self.values = OrderedDict() # value for each trade date
|
||||
self.accounts = OrderedDict() # account position value for each trade time
|
||||
self.returns = OrderedDict() # daily return rate for each trade time
|
||||
self.total_turnovers = OrderedDict() # total turnover for each trade time
|
||||
self.turnovers = OrderedDict() # turnover for each trade time
|
||||
self.total_costs = OrderedDict() # total trade cost for each trade time
|
||||
self.costs = OrderedDict() # trade cost rate for each trade time
|
||||
self.values = OrderedDict() # value for each trade time
|
||||
self.cashes = OrderedDict()
|
||||
self.benches = OrderedDict()
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
@@ -69,6 +80,9 @@ class Report:
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if benchmark is None:
|
||||
return None
|
||||
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
@@ -79,29 +93,20 @@ class Report:
|
||||
raise ValueError("benchmark freq can't be None!")
|
||||
_codes = benchmark if isinstance(benchmark, list) else [benchmark]
|
||||
fields = ["$close/Ref($close,1)-1"]
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq=freq, disk_cache=1)
|
||||
except (ValueError, KeyError):
|
||||
_, norm_freq = parse_freq(freq)
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1)
|
||||
except (ValueError, KeyError):
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
|
||||
elif norm_freq == "minute":
|
||||
_temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1)
|
||||
else:
|
||||
raise ValueError(f"benchmark freq {freq} is not supported")
|
||||
_temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq)
|
||||
if len(_temp_result) == 0:
|
||||
raise ValueError(f"The benchmark {_codes} does not exist. Please provide the right benchmark")
|
||||
return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0)
|
||||
|
||||
def _sample_benchmark(self, bench, trade_start_time, trade_end_time):
|
||||
if self.bench is None:
|
||||
return None
|
||||
|
||||
def cal_change(x):
|
||||
return (x + 1).prod() - 1
|
||||
return (x + 1).prod()
|
||||
|
||||
_ret = resam_ts_data(bench, trade_start_time, trade_end_time, method=cal_change)
|
||||
return 0.0 if _ret is None else _ret
|
||||
return 0.0 if _ret is None else _ret - 1
|
||||
|
||||
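The benchmark sampling above compounds per-bar returns over a step as `prod(1 + r_i) - 1`, and reports `0.0` when the step has no data. A minimal sketch of that arithmetic in plain Python, assuming a hypothetical `compound_return` helper (the diff itself works on a pandas Series through `resam_ts_data`):

```python
import math
from typing import Optional, Sequence


def compound_return(step_returns: Optional[Sequence[float]]) -> float:
    """Compound simple returns over one step; 0.0 when no data is available."""
    if step_returns is None or len(step_returns) == 0:
        return 0.0
    # Multiply the growth factors first, then subtract 1 to get a return again.
    return math.prod(1.0 + r for r in step_returns) - 1.0


three_bars = compound_return([0.01, -0.02, 0.03])
no_data = compound_return(None)
```

For the three bars the growth factor is 1.01 * 0.98 * 1.03, so the compounded return is slightly below the arithmetic sum of 0.02.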
def is_empty(self):
|
||||
return len(self.accounts) == 0
|
||||
@@ -112,6 +117,12 @@ class Report:
|
||||
def get_latest_account_value(self):
|
||||
return self.accounts[self.latest_report_time]
|
||||
|
||||
def get_latest_total_cost(self):
|
||||
return self.total_costs[self.latest_report_time]
|
||||
|
||||
def get_latest_total_turnover(self):
|
||||
return self.total_turnovers[self.latest_report_time]
|
||||
|
||||
def update_report_record(
|
||||
self,
|
||||
trade_start_time=None,
|
||||
@@ -119,41 +130,55 @@ class Report:
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
total_turnover=None,
|
||||
turnover_rate=None,
|
||||
total_cost=None,
|
||||
cost_rate=None,
|
||||
stock_value=None,
|
||||
bench_value=None,
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
total_turnover,
|
||||
turnover_rate,
|
||||
total_cost,
|
||||
cost_rate,
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_start_time, trade_end_time, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
|
||||
)
|
||||
|
||||
if trade_end_time is None and bench_value is None:
|
||||
raise ValueError("Both trade_end_time and bench_value are None, benchmark is not usable.")
|
||||
elif bench_value is None:
|
||||
bench_value = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
|
||||
# update report data
|
||||
self.accounts[trade_start_time] = account_value
|
||||
self.returns[trade_start_time] = return_rate
|
||||
self.total_turnovers[trade_start_time] = total_turnover
|
||||
self.turnovers[trade_start_time] = turnover_rate
|
||||
self.total_costs[trade_start_time] = total_cost
|
||||
self.costs[trade_start_time] = cost_rate
|
||||
self.values[trade_start_time] = stock_value
|
||||
self.cashes[trade_start_time] = cash
|
||||
self.benches[trade_start_time] = self._sample_benchmark(self.bench, trade_start_time, trade_end_time)
|
||||
self.benches[trade_start_time] = bench_value
|
||||
# update latest_report_date
|
||||
self.latest_report_time = trade_start_time
|
||||
# finish daily report update
|
||||
# finish report update in each step
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
report = pd.DataFrame()
|
||||
report["account"] = pd.Series(self.accounts)
|
||||
report["return"] = pd.Series(self.returns)
|
||||
report["total_turnover"] = pd.Series(self.total_turnovers)
|
||||
report["turnover"] = pd.Series(self.turnovers)
|
||||
report["total_cost"] = pd.Series(self.total_costs)
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
@@ -168,7 +193,7 @@ class Report:
|
||||
def load_report(self, path):
|
||||
"""load report from a file
|
||||
should have format like
|
||||
columns = ['account', 'return', 'turnover', 'cost', 'value', 'cash', 'bench']
|
||||
columns = ['account', 'return', 'total_turnover', 'turnover', 'cost', 'total_cost', 'value', 'cash', 'bench']
|
||||
:param
|
||||
path: str/ pathlib.Path()
|
||||
"""
|
||||
@@ -178,14 +203,204 @@ class Report:
|
||||
|
||||
index = r.index
|
||||
self.init_vars()
|
||||
for trade_time in index:
|
||||
for trade_start_time in index:
|
||||
self.update_report_record(
|
||||
trade_time=trade_time,
|
||||
account_value=r.loc[trade_time]["account"],
|
||||
cash=r.loc[trade_time]["cash"],
|
||||
return_rate=r.loc[trade_time]["return"],
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
bench_value=r.loc[trade_time]["bench"],
|
||||
trade_start_time=trade_start_time,
|
||||
account_value=r.loc[trade_start_time]["account"],
|
||||
cash=r.loc[trade_start_time]["cash"],
|
||||
return_rate=r.loc[trade_start_time]["return"],
|
||||
total_turnover=r.loc[trade_start_time]["total_turnover"],
|
||||
turnover_rate=r.loc[trade_start_time]["turnover"],
|
||||
total_cost=r.loc[trade_start_time]["total_cost"],
|
||||
cost_rate=r.loc[trade_start_time]["cost"],
|
||||
stock_value=r.loc[trade_start_time]["value"],
|
||||
bench_value=r.loc[trade_start_time]["bench"],
|
||||
)
|
||||
|
||||
|
||||
class Indicator:
|
||||
def __init__(self):
|
||||
self.order_indicator_his = OrderedDict()
|
||||
self.order_indicator = OrderedDict()
|
||||
self.trade_indicator_his = OrderedDict()
|
||||
self.trade_indicator = OrderedDict()
|
||||
|
||||
def clear(self):
|
||||
self.order_indicator = OrderedDict()
|
||||
self.trade_indicator = OrderedDict()
|
||||
|
||||
def record(self, trade_start_time):
|
||||
self.order_indicator_his[trade_start_time] = self.order_indicator
|
||||
self.trade_indicator_his[trade_start_time] = self.trade_indicator
|
||||
|
||||
def _update_order_trade_info(self, trade_info: list):
|
||||
amount = dict()
|
||||
deal_amount = dict()
|
||||
trade_price = dict()
|
||||
trade_value = dict()
|
||||
trade_cost = dict()
|
||||
|
||||
for order, _trade_val, _trade_cost, _trade_price in trade_info:
|
||||
amount[order.stock_id] = order.amount * (order.direction * 2 - 1)
|
||||
deal_amount[order.stock_id] = order.deal_amount * (order.direction * 2 - 1)
|
||||
trade_price[order.stock_id] = _trade_price
|
||||
trade_value[order.stock_id] = _trade_val * (order.direction * 2 - 1)
|
||||
trade_cost[order.stock_id] = _trade_cost
|
||||
|
||||
self.order_indicator["amount"] = pd.Series(amount)
|
||||
self.order_indicator["deal_amount"] = pd.Series(deal_amount)
|
||||
self.order_indicator["trade_price"] = pd.Series(trade_price)
|
||||
self.order_indicator["trade_value"] = pd.Series(trade_value)
|
||||
self.order_indicator["trade_cost"] = pd.Series(trade_cost)
|
||||
|
||||
def _update_order_fulfill_rate(self):
|
||||
self.order_indicator["ffr"] = self.order_indicator["deal_amount"] / self.order_indicator["amount"]
|
||||
|
||||
def _update_order_price_advantage(self, trade_exchange, trade_start_time, trade_end_time):
|
||||
self.order_indicator["base_price"] = self.order_indicator["trade_price"]
|
||||
instruments = list(self.order_indicator["base_price"].index)
|
||||
self.order_indicator["volume"] = pd.Series(
|
||||
[
|
||||
trade_exchange.get_volume(stock_id=inst, start_time=trade_start_time, end_time=trade_end_time)
|
||||
for inst in instruments
|
||||
],
|
||||
index=instruments,
|
||||
)
|
||||
self.order_indicator["pa"] = (
|
||||
self.order_indicator["trade_price"] - self.order_indicator["base_price"]
|
||||
) / self.order_indicator["base_price"]
|
||||
|
||||
def _agg_order_trade_info(self, inner_order_indicators):
|
||||
amount = pd.Series()
|
||||
deal_amount = pd.Series()
|
||||
trade_price = pd.Series()
|
||||
trade_value = pd.Series()
|
||||
trade_cost = pd.Series()
|
||||
for _order_indicator in inner_order_indicators:
|
||||
amount = amount.add(_order_indicator["amount"], fill_value=0)
|
||||
deal_amount = deal_amount.add(_order_indicator["deal_amount"], fill_value=0)
|
||||
trade_price = trade_price.add(
|
||||
_order_indicator["trade_price"] * _order_indicator["deal_amount"], fill_value=0
|
||||
)
|
||||
trade_value = trade_value.add(_order_indicator["trade_value"], fill_value=0)
|
||||
trade_cost = trade_cost.add(_order_indicator["trade_cost"], fill_value=0)
|
||||
|
||||
self.order_indicator["amount"] = amount
|
||||
self.order_indicator["deal_amount"] = deal_amount
|
||||
trade_price /= self.order_indicator["deal_amount"]
|
||||
self.order_indicator["trade_price"] = trade_price
|
||||
self.order_indicator["trade_value"] = trade_value
|
||||
self.order_indicator["trade_cost"] = trade_cost
|
||||
|
||||
def _agg_order_fulfill_rate(self):
|
||||
self.order_indicator["ffr"] = self.order_indicator["deal_amount"] / self.order_indicator["amount"]
|
||||
|
||||
def _agg_order_price_advantage(self, inner_order_indicators, base_price="twap"):
|
||||
base_price = base_price.lower()
|
||||
volume = pd.Series()
|
||||
for _order_indicator in inner_order_indicators:
|
||||
volume = volume.add(_order_indicator["volume"], fill_value=0)
|
||||
self.order_indicator["volume"] = volume
|
||||
|
||||
if base_price == "twap":
|
||||
base_price = pd.Series()
|
||||
price_count = pd.Series()
|
||||
for _order_indicator in inner_order_indicators:
|
||||
base_price = base_price.add(_order_indicator["base_price"], fill_value=0)
|
||||
price_count = price_count.add(pd.Series(1, index=_order_indicator["base_price"].index), fill_value=0)
|
||||
base_price /= price_count
|
||||
self.order_indicator["base_price"] = base_price
|
||||
|
||||
elif base_price == "vwap":
|
||||
base_price = pd.Series()
|
||||
for _order_indicator in inner_order_indicators:
|
||||
base_price = base_price.add(_order_indicator["base_price"] * _order_indicator["volume"], fill_value=0)
|
||||
base_price /= self.order_indicator["volume"]
|
||||
self.order_indicator["base_price"] = base_price
|
||||
|
||||
else:
|
||||
raise ValueError(f"base_price {base_price} is not supported!")
|
||||
|
||||
self.order_indicator["pa"] = self.order_indicator["trade_price"] / self.order_indicator["base_price"] - 1
|
||||
# print("trade_price", self.order_indicator["trade_price"], "base_price", self.order_indicator["base_price"], "pa", self.order_indicator["pa"]* (2 * (self.order_indicator["amount"] < 0).astype(int) - 1))
|
||||
|
||||
def _cal_trade_fulfill_rate(self, method="mean"):
|
||||
if method == "mean":
|
||||
return self.order_indicator["ffr"].mean()
|
||||
elif method == "amount_weighted":
|
||||
weights = self.order_indicator["deal_amount"].abs()
|
||||
return (self.order_indicator["ffr"] * weights).sum() / weights.sum()
|
||||
elif method == "value_weighted":
|
||||
weights = self.order_indicator["trade_value"].abs()
|
||||
return (self.order_indicator["ffr"] * weights).sum() / weights.sum()
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
|
||||
def _cal_trade_price_advantage(self, method="mean"):
|
||||
pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1)
|
||||
if method == "mean":
|
||||
return pa_order.mean()
|
||||
elif method == "amount_weighted":
|
||||
weights = self.order_indicator["deal_amount"].abs()
|
||||
return (pa_order * weights).sum() / weights.sum()
|
||||
elif method == "value_weighted":
|
||||
weights = self.order_indicator["trade_value"].abs()
|
||||
return (pa_order * weights).sum() / weights.sum()
|
||||
else:
|
||||
raise ValueError(f"method {method} is not supported!")
|
||||
|
||||
def _cal_trade_positive_rate(self):
|
||||
pa_order = self.order_indicator["pa"] * (2 * (self.order_indicator["amount"] < 0).astype(int) - 1)
|
||||
return (pa_order > 0).astype(int).sum() / pa_order.count()
|
||||
|
||||
def _cal_trade_amount(self):
|
||||
return self.order_indicator["deal_amount"].abs().sum()
|
||||
|
||||
def _cal_trade_value(self):
|
||||
return self.order_indicator["trade_value"].abs().sum()
|
||||
|
||||
def _cal_trade_order_count(self):
|
||||
return self.order_indicator["amount"].count()
|
||||
|
||||
def update_order_indicators(self, trade_start_time, trade_end_time, trade_info, trade_exchange):
|
||||
self._update_order_trade_info(trade_info=trade_info)
|
||||
self._update_order_fulfill_rate()
|
||||
self._update_order_price_advantage(trade_exchange, trade_start_time, trade_end_time)
|
||||
|
||||
def agg_order_indicators(self, inner_order_indicators, indicator_config={}):
|
||||
self._agg_order_trade_info(inner_order_indicators)
|
||||
self._agg_order_fulfill_rate()
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
self._agg_order_price_advantage(inner_order_indicators, base_price=pa_config.get("base_price", "twap"))
|
||||
|
||||
def cal_trade_indicators(self, trade_start_time, freq, indicator_config={}):
|
||||
show_indicator = indicator_config.get("show_indicator", False)
|
||||
ffr_config = indicator_config.get("ffr_config", {})
|
||||
pa_config = indicator_config.get("pa_config", {})
|
||||
fulfill_rate = self._cal_trade_fulfill_rate(method=ffr_config.get("weight_method", "mean"))
|
||||
price_advantage = self._cal_trade_price_advantage(method=pa_config.get("weight_method", "mean"))
|
||||
positive_rate = self._cal_trade_positive_rate()
|
||||
trade_amount = self._cal_trade_amount()
|
||||
trade_value = self._cal_trade_value()
|
||||
order_count = self._cal_trade_order_count()
|
||||
self.trade_indicator["ffr"] = fulfill_rate
|
||||
self.trade_indicator["pa"] = price_advantage
|
||||
self.trade_indicator["pos"] = positive_rate
|
||||
self.trade_indicator["amount"] = trade_amount
|
||||
self.trade_indicator["value"] = trade_value
|
||||
self.trade_indicator["count"] = order_count
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
|
||||
)
|
||||
)
|
||||
|
||||
def get_order_indicator(self):
|
||||
return self.order_indicator
|
||||
|
||||
def get_trade_indicator(self):
|
||||
return self.trade_indicator
|
||||
|
||||
def generate_trade_indicators_dataframe(self):
|
||||
return pd.DataFrame.from_dict(self.trade_indicator_his, orient="index")
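The aggregation done by `_agg_order_trade_info`, `_agg_order_fulfill_rate` and `_agg_order_price_advantage` can be sketched outside the class as below. This is a minimal illustration under stated assumptions, not the class's own code: the function name `aggregate_order_indicators` and the sample numbers are hypothetical, and it uses `pd.concat(...).groupby(level=0)` in place of the repeated `Series.add(..., fill_value=0)` calls, which should give the same per-instrument sums.

```python
import pandas as pd

# Hypothetical inner-step indicators; signed amounts (positive = buy, negative = sell).
step1 = {
    "amount": pd.Series({"A": 100.0, "B": -50.0}),
    "deal_amount": pd.Series({"A": 80.0, "B": -50.0}),
    "trade_price": pd.Series({"A": 10.0, "B": 20.0}),
    "base_price": pd.Series({"A": 10.0, "B": 20.0}),
    "volume": pd.Series({"A": 1000.0, "B": 3000.0}),
}
step2 = {
    "amount": pd.Series({"A": 100.0}),
    "deal_amount": pd.Series({"A": 100.0}),
    "trade_price": pd.Series({"A": 12.0}),
    "base_price": pd.Series({"A": 12.0}),
    "volume": pd.Series({"A": 3000.0}),
}


def aggregate_order_indicators(inner, base_price="twap"):
    """Combine per-step order indicators into one outer-step table."""

    def total(key):
        # Sum a field per instrument across all inner steps.
        return pd.concat([ind[key] for ind in inner]).groupby(level=0).sum()

    amount = total("amount")
    deal_amount = total("deal_amount")
    # Deal-amount-weighted average trade price.
    notional = pd.concat([ind["trade_price"] * ind["deal_amount"] for ind in inner])
    trade_price = notional.groupby(level=0).sum() / deal_amount

    if base_price == "twap":
        # Plain average over the steps in which the instrument appears.
        base = pd.concat([ind["base_price"] for ind in inner]).groupby(level=0).mean()
    elif base_price == "vwap":
        # Market-volume-weighted average of the per-step base prices.
        weighted = pd.concat([ind["base_price"] * ind["volume"] for ind in inner])
        base = weighted.groupby(level=0).sum() / total("volume")
    else:
        raise ValueError(f"base_price {base_price} is not supported!")

    return pd.DataFrame(
        {
            "amount": amount,
            "deal_amount": deal_amount,
            "ffr": deal_amount / amount,
            "trade_price": trade_price,
            "base_price": base,
            "pa": trade_price / base - 1,
        }
    )


twap_result = aggregate_order_indicators([step1, step2], base_price="twap")
vwap_result = aggregate_order_indicators([step1, step2], base_price="vwap")
```

With these made-up inputs, instrument `A` fills 180 of 200 shares, and its base price differs between the two modes because the second step carries three times the market volume of the first.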
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
from typing import Union
|
||||
from typing import Tuple, Union, List, Set
|
||||
|
||||
from ..utils.resam import get_resam_calendar
|
||||
from ..data.data import Cal
|
||||
@@ -74,7 +74,12 @@ class TradeCalendarManager:
|
||||
|
||||
def get_step_time(self, trade_step=0, shift=0):
|
||||
"""
|
||||
Get the time range of trading step
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
About the endpoints:
|
||||
- Qlib uses closed intervals in time-series data selection, which matches the behavior of pandas.Series.loc
|
||||
- The returned right endpoint has 1 second subtracted because of the closed-interval representation in Qlib.
|
||||
Note: Qlib supports up to minutely decision execution, so 1 second is shorter than any trading time interval.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -98,6 +103,9 @@ class TradeCalendarManager:
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
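The closed-interval convention described in the `get_step_time` docstring can be illustrated with plain pandas. The sketch below uses made-up 30-minute timestamps and does not call any Qlib API; it only shows why subtracting one second from the right endpoint keeps adjacent steps from sharing a bar.

```python
import pandas as pd

# Hypothetical 30-minute bars: 09:30, 10:00, 10:30, 11:00.
index = pd.date_range("2021-01-04 09:30", periods=4, freq="30min")
series = pd.Series(range(4), index=index)

left = pd.Timestamp("2021-01-04 09:30")
next_left = pd.Timestamp("2021-01-04 10:30")  # left endpoint of the following step

# .loc slices are closed on both ends, so the 10:30 bar leaks into this step.
with_overlap = series.loc[left:next_left]

# Subtracting one second from the right endpoint excludes the next step's first bar.
right = next_left - pd.Timedelta(seconds=1)
without_overlap = series.loc[left:right]
```

Here `with_overlap` holds three bars while `without_overlap` holds the two bars that belong to the step.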
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs):
|
||||
|
||||
@@ -11,7 +11,7 @@ import warnings
|
||||
from ..log import get_module_logger
|
||||
from ..backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.resam import Freq
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -35,14 +35,14 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"""
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = parse_freq(freq)
|
||||
_count, _freq = Freq.parse(freq)
|
||||
_freq_scaler = {
|
||||
"minute": 240 * 252,
|
||||
"day": 252,
|
||||
"week": 50,
|
||||
"month": 12,
|
||||
Freq.NORM_FREQ_MINUTE: 240 * 252,
|
||||
Freq.NORM_FREQ_DAY: 252,
|
||||
Freq.NORM_FREQ_WEEK: 50,
|
||||
Freq.NORM_FREQ_MONTH: 12,
|
||||
}
|
||||
return _count * _freq_scaler[_freq]
|
||||
return _freq_scaler[_freq] / _count
|
||||
|
||||
if N is None and freq is None:
|
||||
raise ValueError("at least one of `N` and `freq` should exist")
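The change from `_count * _freq_scaler[_freq]` to `_freq_scaler[_freq] / _count` in `cal_risk_analysis_scaler` can be checked with a small standalone sketch. The helper `parse_freq` below is a hypothetical stand-in for `Freq.parse` (its accepted spellings are an assumption), and the table is keyed by plain strings rather than the `Freq.NORM_FREQ_*` constants.

```python
import re

# Periods per year for a count of 1, using the same numbers as the diff above.
PERIODS_PER_YEAR = {"minute": 240 * 252, "day": 252, "week": 50, "month": 12}


def parse_freq(freq):
    """Hypothetical parser: "5minute" -> (5, "minute"), "day" -> (1, "day")."""
    match = re.fullmatch(r"(\d*)(minute|day|week|month)", freq)
    if match is None:
        raise ValueError(f"unsupported freq: {freq}")
    count = int(match.group(1)) if match.group(1) else 1
    return count, match.group(2)


def annualization_scaler(freq):
    """Number of bars of the given frequency in one year."""
    count, unit = parse_freq(freq)
    # A bar that spans `count` base periods occurs `count` times less often.
    return PERIODS_PER_YEAR[unit] / count


daily = annualization_scaler("day")
five_minute = annualization_scaler("5minute")
```

A 5-minute bar occurs one fifth as often as a 1-minute bar, so the scaler has to shrink as the count grows, which is what the division expresses.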
|
||||
@@ -63,7 +63,55 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"information_ratio": information_ratio,
|
||||
"max_drawdown": max_drawdown,
|
||||
}
|
||||
res = pd.Series(data, index=data.keys()).to_frame("risk")
|
||||
res = pd.Series(data).to_frame("risk")
|
||||
return res
|
||||
|
||||
|
||||
def indicator_analysis(df, method="mean"):
|
||||
"""analyze statistical time-series indicators of trading
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : pandas.DataFrame
|
||||
columns: like ['pa', 'pos', 'ffr', 'count', 'amount', 'value'].
|
||||
Necessary fields:
|
||||
- 'pa' is the price advantage in trade indicators
|
||||
- 'pos' is the positive rate in trade indicators
|
||||
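- 'ffr' is the fulfill rate in trade indicators
- 'count' is the order count in trade indicators, used as the weight when method is 'mean' and for 'pos'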
|
||||
Optional fields:
|
||||
- 'amount' is the total deal amount, only necessary when method is 'amount_weighted'
|
||||
- 'value' is the total trade value, only necessary when method is 'value_weighted'
|
||||
|
||||
index: Index(datetime)
|
||||
method : str, optional
|
||||
statistics method of pa/ffr, by default "mean"
|
||||
- if method is 'mean', compute the order-count-weighted mean of each trade indicator
|
||||
- if method is 'amount_weighted', compute the amount-weighted mean of each trade indicator
|
||||
- if method is 'value_weighted', compute the value-weighted mean of each trade indicator
|
||||
Note: statistics method of pos is always "mean"
|
||||
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
statistical value of each trade indicator
|
||||
"""
|
||||
weights_dict = {
|
||||
"mean": df["count"],
|
||||
"amount_weighted": df["amount"].abs(),
|
||||
"value_weighted": df["value"].abs(),
|
||||
}
|
||||
if method not in weights_dict:
|
||||
raise ValueError(f"indicator_analysis method {method} is not supported!")
|
||||
|
||||
# statistic pa/ffr indicator
|
||||
indicators_df = df[["ffr", "pa"]]
|
||||
weights = weights_dict.get(method)
|
||||
res = indicators_df.mul(weights, axis=0).sum() / weights.sum()
|
||||
|
||||
# statistic pos
|
||||
weights = weights_dict.get("mean")
|
||||
res.loc["pos"] = df["pos"].mul(weights).sum() / weights.sum()
|
||||
res = res.to_frame("value")
|
||||
return res
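The weighting logic of `indicator_analysis` can be reproduced on a toy table as below. This is a sketch with invented numbers, and the helper name `weighted_summary` is hypothetical. It shows that `'pa'` and `'ffr'` follow the chosen weight column, while `'pos'` is always weighted by the order count.

```python
import pandas as pd

# Hypothetical trade indicators for two trading steps.
df = pd.DataFrame(
    {
        "ffr": [1.0, 0.5],
        "pa": [0.01, -0.02],
        "pos": [1.0, 0.0],
        "count": [1, 3],
        "amount": [100.0, 300.0],
        "value": [1000.0, 1000.0],
    },
    index=pd.to_datetime(["2021-01-04", "2021-01-05"]),
)


def weighted_summary(frame, weight_col):
    """Weighted mean of ffr/pa by `weight_col`; pos is weighted by order count."""
    weights = frame[weight_col].abs()
    res = frame[["ffr", "pa"]].mul(weights, axis=0).sum() / weights.sum()
    counts = frame["count"]
    res.loc["pos"] = frame["pos"].mul(counts).sum() / counts.sum()
    return res.to_frame("value")


by_count = weighted_summary(df, "count")  # corresponds to method="mean"
by_value = weighted_summary(df, "value")  # corresponds to method="value_weighted"
```

Because the second step has three orders, it dominates the count-weighted result, whereas equal trade values give both steps the same influence in `by_value`.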
|
||||
|
||||
|
||||
|
||||
393
qlib/contrib/model/pytorch_tcts.py
Normal file
@@ -0,0 +1,393 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import copy
|
||||
from sklearn.metrics import roc_auc_score, mean_squared_error
|
||||
import logging
|
||||
from ...utils import (
|
||||
unpack_archive_with_buffer,
|
||||
save_multiple_parts_file,
|
||||
create_save_path,
|
||||
drop_nan_by_y_index,
|
||||
)
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
|
||||
from ...model.base import Model
|
||||
from ...data.dataset import DatasetH
|
||||
from ...data.dataset.handler import DataHandlerLP
|
||||
|
||||
|
||||
class TCTS(Model):
|
||||
"""TCTS Model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
d_feat : int
|
||||
input dimension for each time step
|
||||
metric: str
|
||||
the evaluation metric used for early stopping
|
||||
optimizer : str
|
||||
optimizer name
|
||||
GPU : int
|
||||
the GPU ID(s) used for training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_feat=6,
|
||||
hidden_size=64,
|
||||
num_layers=2,
|
||||
dropout=0.0,
|
||||
n_epochs=200,
|
||||
batch_size=2000,
|
||||
early_stop=20,
|
||||
loss="mse",
|
||||
fore_optimizer="adam",
|
||||
weight_optimizer="adam",
|
||||
output_dim=5,
|
||||
fore_lr=5e-7,
|
||||
weight_lr=5e-7,
|
||||
steps=3,
|
||||
GPU=0,
|
||||
seed=None,
|
||||
target_label=0,
|
||||
**kwargs
|
||||
):
|
||||
# Set logger.
|
||||
self.logger = get_module_logger("TCTS")
|
||||
self.logger.info("TCTS pytorch version...")
|
||||
|
||||
# set hyper-parameters.
|
||||
self.d_feat = d_feat
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.dropout = dropout
|
||||
self.n_epochs = n_epochs
|
||||
self.batch_size = batch_size
|
||||
self.early_stop = early_stop
|
||||
self.loss = loss
|
||||
self.device = torch.device("cuda:%d" % (GPU) if torch.cuda.is_available() else "cpu")
|
||||
self.use_gpu = torch.cuda.is_available()
|
||||
self.seed = seed
|
||||
self.output_dim = output_dim
|
||||
self.fore_lr = fore_lr
|
||||
self.weight_lr = weight_lr
|
||||
self.steps = steps
|
||||
self.target_label = target_label
|
||||
|
||||
self.logger.info(
|
||||
"TCTS parameters setting:"
|
||||
"\nd_feat : {}"
|
||||
"\nhidden_size : {}"
|
||||
"\nnum_layers : {}"
|
||||
"\ndropout : {}"
|
||||
"\nn_epochs : {}"
|
||||
"\nbatch_size : {}"
|
||||
"\nearly_stop : {}"
|
||||
"\nloss_type : {}"
|
||||
"\nvisible_GPU : {}"
|
||||
"\nuse_GPU : {}"
|
||||
"\nseed : {}".format(
|
||||
d_feat,
|
||||
hidden_size,
|
||||
num_layers,
|
||||
dropout,
|
||||
n_epochs,
|
||||
batch_size,
|
||||
early_stop,
|
||||
loss,
|
||||
GPU,
|
||||
self.use_gpu,
|
||||
seed,
|
||||
)
|
||||
)
|
||||
|
||||
if self.seed is not None:
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.fore_model = GRUModel(
|
||||
d_feat=self.d_feat,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
)
|
||||
self.weight_model = MLPModel(
|
||||
d_feat=360 + 2 * self.output_dim + 1,
|
||||
hidden_size=self.hidden_size,
|
||||
num_layers=self.num_layers,
|
||||
dropout=self.dropout,
|
||||
output_dim=self.output_dim,
|
||||
)
|
||||
if fore_optimizer.lower() == "adam":
|
||||
self.fore_optimizer = optim.Adam(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
elif fore_optimizer.lower() == "gd":
|
||||
self.fore_optimizer = optim.SGD(self.fore_model.parameters(), lr=self.fore_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(fore_optimizer))
|
||||
if weight_optimizer.lower() == "adam":
|
||||
self.weight_optimizer = optim.Adam(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
elif weight_optimizer.lower() == "gd":
|
||||
self.weight_optimizer = optim.SGD(self.weight_model.parameters(), lr=self.weight_lr)
|
||||
else:
|
||||
raise NotImplementedError("optimizer {} is not supported!".format(weight_optimizer))
|
||||
|
||||
self.fitted = False
|
||||
self.fore_model.to(self.device)
|
||||
self.weight_model.to(self.device)
|
||||
|
||||
def loss_fn(self, pred, label, weight):
|
||||
|
||||
loc = torch.argmax(weight, 1)
|
||||
loss = (pred - label[np.arange(weight.shape[0]), loc]) ** 2
|
||||
return torch.mean(loss)
|
||||
|
||||
def train_epoch(self, x_train, y_train, x_valid, y_valid):
|
||||
|
||||
x_train_values = x_train.values
|
||||
y_train_values = np.squeeze(y_train.values)
|
||||
|
||||
indices = np.arange(len(x_train_values))
|
||||
np.random.shuffle(indices)
|
||||
|
||||
init_fore_model = copy.deepcopy(self.fore_model)
|
||||
for p in init_fore_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
self.fore_model.train()
|
||||
self.weight_model.train()
|
||||
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = False
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = True
|
||||
|
||||
for _ in range(self.steps):
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_train_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
init_pred = init_fore_model(feature)
|
||||
pred = self.fore_model(feature)
|
||||
|
||||
dis = init_pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, init_pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
|
||||
loss = self.loss_fn(pred, label, weight) # hard
|
||||
|
||||
self.fore_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.fore_model.parameters(), 3.0)
|
||||
self.fore_optimizer.step()
|
||||
|
||||
x_valid_values = x_valid.values
|
||||
y_valid_values = np.squeeze(y_valid.values)
|
||||
|
||||
indices = np.arange(len(x_valid_values))
|
||||
np.random.shuffle(indices)
|
||||
for p in self.weight_model.parameters():
|
||||
p.requires_grad = True
|
||||
for p in self.fore_model.parameters():
|
||||
p.requires_grad = False
|
||||
|
||||
# freeze the forecasting model and train the weight model on validation data
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_valid_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
dis = pred - label.transpose(0, 1)
|
||||
weight_feature = torch.cat((feature, dis.transpose(0, 1), label, pred.view(-1, 1)), 1)
|
||||
weight = self.weight_model(weight_feature)
|
||||
loc = torch.argmax(weight, 1)
|
||||
valid_loss = torch.mean((pred - label[:, 0]) ** 2)
|
||||
loss = torch.mean(-valid_loss * torch.log(weight[np.arange(weight.shape[0]), loc]))
|
||||
|
||||
self.weight_optimizer.zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_value_(self.weight_model.parameters(), 3.0)
|
||||
self.weight_optimizer.step()
|
||||
|
||||
def test_epoch(self, data_x, data_y):
|
||||
|
||||
# prepare evaluation data
|
||||
x_values = data_x.values
|
||||
y_values = np.squeeze(data_y.values)
|
||||
|
||||
self.fore_model.eval()
|
||||
|
||||
scores = []
|
||||
losses = []
|
||||
|
||||
indices = np.arange(len(x_values))
|
||||
|
||||
for i in range(len(indices))[:: self.batch_size]:
|
||||
|
||||
if len(indices) - i < self.batch_size:
|
||||
break
|
||||
|
||||
feature = torch.from_numpy(x_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
label = torch.from_numpy(y_values[indices[i : i + self.batch_size]]).float().to(self.device)
|
||||
|
||||
pred = self.fore_model(feature)
|
||||
loss = torch.mean((pred - label[:, abs(self.target_label)]) ** 2)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
|
||||
def fit(
|
||||
self,
|
||||
dataset: DatasetH,
|
||||
evals_result=dict(),
|
||||
verbose=True,
|
||||
save_path=None,
|
||||
):
|
||||
|
||||
df_train, df_valid, df_test = dataset.prepare(
|
||||
["train", "valid", "test"],
|
||||
col_set=["feature", "label"],
|
||||
data_key=DataHandlerLP.DK_L,
|
||||
)
|
||||
|
||||
x_train, y_train = df_train["feature"], df_train["label"]
|
||||
x_valid, y_valid = df_valid["feature"], df_valid["label"]
|
||||
x_test, y_test = df_test["feature"], df_test["label"]
|
||||
|
||||
if save_path is None:
|
||||
save_path = create_save_path(save_path)
|
||||
|
||||
best_loss = np.inf
|
||||
best_epoch = 0
|
||||
stop_round = 0
|
||||
fore_best_param = copy.deepcopy(self.fore_optimizer.state_dict())
|
||||
weight_best_param = copy.deepcopy(self.weight_optimizer.state_dict())
|
||||
|
||||
for epoch in range(self.n_epochs):
|
||||
print("Epoch:", epoch)
|
||||
|
||||
print("training...")
|
||||
self.train_epoch(x_train, y_train, x_valid, y_valid)
|
||||
print("evaluating...")
|
||||
val_loss = self.test_epoch(x_valid, y_valid)
|
||||
test_loss = self.test_epoch(x_test, y_test)
|
||||
|
||||
print("valid %.6f, test %.6f" % (val_loss, test_loss))
|
||||
|
||||
if val_loss < best_loss:
|
||||
best_loss = val_loss
|
||||
stop_round = 0
|
||||
best_epoch = epoch
|
||||
torch.save(copy.deepcopy(self.fore_model.state_dict()), save_path + "_fore_model.bin")
|
||||
torch.save(copy.deepcopy(self.weight_model.state_dict()), save_path + "_weight_model.bin")
|
||||
|
||||
else:
|
||||
stop_round += 1
|
||||
if stop_round >= self.early_stop:
|
||||
print("early stop")
|
||||
break
|
||||
|
||||
print("best loss:", best_loss, "@", best_epoch)
|
||||
best_param = torch.load(save_path + "_fore_model.bin")
|
||||
self.fore_model.load_state_dict(best_param)
|
||||
best_param = torch.load(save_path + "_weight_model.bin")
|
||||
self.weight_model.load_state_dict(best_param)
|
||||
self.fitted = True
|
||||
|
||||
if self.use_gpu:
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def predict(self, dataset):
|
||||
if not self.fitted:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
|
||||
x_test = dataset.prepare("test", col_set="feature")
|
||||
index = x_test.index
|
||||
self.fore_model.eval()
|
||||
x_values = x_test.values
|
||||
sample_num = x_values.shape[0]
|
||||
preds = []
|
||||
|
||||
for begin in range(sample_num)[:: self.batch_size]:
|
||||
|
||||
if sample_num - begin < self.batch_size:
|
||||
end = sample_num
|
||||
else:
|
||||
end = begin + self.batch_size
|
||||
|
||||
x_batch = torch.from_numpy(x_values[begin:end]).float().to(self.device)
|
||||
|
||||
with torch.no_grad():
|
||||
if self.use_gpu:
|
||||
pred = self.fore_model(x_batch).detach().cpu().numpy()
|
||||
else:
|
||||
pred = self.fore_model(x_batch).detach().numpy()
|
||||
|
||||
preds.append(pred)
|
||||
|
||||
return pd.Series(np.concatenate(preds), index=index)
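The hard task selection used by `TCTS.loss_fn` can be sketched with NumPy as below. This is an illustration only: the function name `hard_selection_loss` and the sample arrays are hypothetical, and NumPy replaces the torch tensors of the source so that the example runs without a GPU framework.

```python
import numpy as np


def hard_selection_loss(pred, label, weight):
    """Mean squared error against the label column chosen by argmax(weight)."""
    # For every sample, pick the task (label column) with the largest scheduler weight.
    loc = np.argmax(weight, axis=1)
    chosen = label[np.arange(weight.shape[0]), loc]
    return np.mean((pred - chosen) ** 2)


# Two samples, three candidate tasks (label columns).
pred = np.array([0.1, 0.2])
label = np.array([[0.0, 0.3, 0.5], [0.2, 0.1, 0.4]])
weight = np.array([[0.2, 0.7, 0.1], [0.6, 0.3, 0.1]])

loss = hard_selection_loss(pred, label, weight)
```

The scheduler picks task 1 for the first sample and task 0 for the second, so only those two label entries contribute to the loss.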
|
||||
|
||||
|
||||
class MLPModel(nn.Module):
|
||||
def __init__(self, d_feat, hidden_size=256, num_layers=3, dropout=0.0, output_dim=1):
|
||||
super().__init__()
|
||||
|
||||
self.mlp = nn.Sequential()
|
||||
self.softmax = nn.Softmax(dim=1)
|
||||
|
||||
for i in range(num_layers):
|
||||
if i > 0:
|
||||
self.mlp.add_module("drop_%d" % i, nn.Dropout(dropout))
|
||||
self.mlp.add_module("fc_%d" % i, nn.Linear(d_feat if i == 0 else hidden_size, hidden_size))
|
||||
self.mlp.add_module("relu_%d" % i, nn.ReLU())
|
||||
|
||||
self.mlp.add_module("fc_out", nn.Linear(hidden_size, output_dim))
|
||||
|
||||
def forward(self, x):
|
||||
# feature
|
||||
# [N, F]
|
||||
out = self.mlp(x).squeeze()
|
||||
out = self.softmax(out)
|
||||
return out
|
||||
|
||||
|
||||
class GRUModel(nn.Module):
|
||||
def __init__(self, d_feat=6, hidden_size=64, num_layers=2, dropout=0.0):
|
||||
super().__init__()
|
||||
|
||||
self.rnn = nn.GRU(
|
||||
input_size=d_feat,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
batch_first=True,
|
||||
dropout=dropout,
|
||||
)
|
||||
self.fc_out = nn.Linear(hidden_size, 1)
|
||||
|
||||
self.d_feat = d_feat
|
||||
|
||||
def forward(self, x):
|
||||
# x: [N, F*T]
|
||||
x = x.reshape(len(x), self.d_feat, -1) # [N, F, T]
|
||||
x = x.permute(0, 2, 1) # [N, T, F]
|
||||
out, _ = self.rnn(x)
|
||||
return self.fc_out(out[:, -1, :]).squeeze()
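The reshape and permute at the start of `GRUModel.forward` imply that the flat input stores all time steps of one feature before moving on to the next feature. The NumPy sketch below walks through that layout with small, made-up sizes; it assumes this feature-major ordering and uses `transpose` where the source uses `permute`.

```python
import numpy as np

n_samples, d_feat, n_steps = 2, 3, 4

# Flat input of shape [N, F*T]; sample 0 holds the values 0..11.
flat = np.arange(n_samples * d_feat * n_steps).reshape(n_samples, d_feat * n_steps)

# [N, F*T] -> [N, F, T]: each row now holds the time series of one feature.
by_feature = flat.reshape(n_samples, d_feat, -1)

# [N, F, T] -> [N, T, F]: the layout a batch_first recurrent layer expects.
by_time = by_feature.transpose(0, 2, 1)

first_step = by_time[0, 0]   # all features of sample 0 at the first time step
last_step = by_time[0, -1]   # all features of sample 0 at the last time step
```

If the flat input were instead stored time-major, the reshape would silently mix features and time steps, so the column order of the dataset matters here.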
|
||||
@@ -62,7 +62,7 @@ class XGBModel(Model, FeatureInt):
|
||||
if self.model is None:
|
||||
raise ValueError("model is not fitted yet!")
|
||||
x_test = dataset.prepare(segment, col_set="feature", data_key=DataHandlerLP.DK_I)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test.values)), index=x_test.index)
|
||||
return pd.Series(self.model.predict(xgb.DMatrix(x_test)), index=x_test.index)
|
||||
|
||||
def get_feature_importance(self, *args, **kwargs) -> pd.Series:
|
||||
"""get feature importance
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
"""
|
||||
This strategy is not well maintained
|
||||
"""
|
||||
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
@@ -1,15 +1,21 @@
|
||||
import copy
|
||||
from qlib.backtest.position import Position
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.order import Order, BaseTradeDecision, TradeDecisionWO
|
||||
|
||||
from .order_generator import OrderGenWInteract
|
||||
|
||||
|
||||
class TopkDropoutStrategy(ModelStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting use of the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@@ -51,6 +57,11 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allows different trade_exchanges to be used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both the daily and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable, so the minutely exchange is recommended.
|
||||
|
||||
"""
|
||||
super(TopkDropoutStrategy, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
@@ -94,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
return TradeDecisionWO([], self)
|
||||
if self.only_tradable:
|
||||
# If the strategy only considers tradable stocks when making decisions,
|
||||
# it needs the following steps to filter stocks
|
||||
@@ -239,10 +250,14 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
factor=factor,
|
||||
)
|
||||
buy_order_list.append(buy_order)
|
||||
return sell_order_list + buy_order_list
|
||||
return TradeDecisionWO(sell_order_list + buy_order_list, self)
|
||||
|
||||
|
||||
class WeightStrategyBase(ModelStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting use of the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@@ -253,6 +268,15 @@ class WeightStrategyBase(ModelStrategy):
|
||||
common_infra=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allows different trade_exchanges to be used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both the daily and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable, so the minutely exchange is recommended.
|
||||
"""
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs
|
||||
)
|
||||
@@ -301,18 +325,6 @@ class WeightStrategyBase(ModelStrategy):
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
score_series : pd.Seires
|
||||
stock_id , score.
|
||||
current : Position()
|
||||
current of account.
|
||||
trade_exchange : Exchange()
|
||||
exchange.
|
||||
trade_date : pd.Timestamp
|
||||
date.
|
||||
"""
|
||||
# generate_trade_decision
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
|
||||
@@ -322,8 +334,10 @@ class WeightStrategyBase(ModelStrategy):
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
return TradeDecisionWO([], self)
|
||||
current_temp = copy.deepcopy(self.trade_position)
|
||||
assert isinstance(current_temp, Position) # Avoid InfPosition
|
||||
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
)
|
||||
@@ -337,4 +351,4 @@ class WeightStrategyBase(ModelStrategy):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecisionWO(order_list, self)
|
||||
|
||||
@@ -6,6 +6,8 @@ This order generator is for strategies based on WeightStrategyBase
|
||||
"""
|
||||
from ...backtest.position import Position
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.order import BaseTradeDecision, TradeDecisionWO
|
||||
|
||||
import pandas as pd
|
||||
import copy
|
||||
|
||||
@@ -125,7 +127,7 @@ class OrderGenWInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecisionWO(order_list, self)
|
||||
|
||||
|
||||
class OrderGenWOInteract(OrderGenerator):
|
||||
@@ -189,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator):
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
)
|
||||
return order_list
|
||||
return TradeDecisionWO(order_list, self)
|
||||
|
||||
@@ -1,21 +1,46 @@
|
||||
import warnings
|
||||
from typing import List, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from ...utils.resam import resam_ts_data
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...backtest.order import Order
|
||||
from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO
|
||||
from ...backtest.exchange import Exchange
|
||||
from ...backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
|
||||
|
||||
def get_start_end_idx(strategy: BaseStrategy, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
Parameters
|
||||
----------
|
||||
strategy : BaseStrategy
|
||||
the inner strategy
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision made by outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[int, int]:
|
||||
start index and end index
|
||||
"""
|
||||
try:
|
||||
return outer_trade_decision.get_range_limit()
|
||||
except NotImplementedError:
|
||||
return 0, strategy.trade_calendar.get_trade_len() - 1
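The fallback in `get_start_end_idx` can be illustrated in isolation. The sketch below is a minimal, self-contained approximation: `DecisionWithoutRange`, `DecisionWithRange`, and `start_end_idx` are hypothetical stand-ins written for this example (they are not qlib classes), and the trade calendar is reduced to a plain `trade_len` integer.

```python
from typing import Tuple


class DecisionWithoutRange:
    """Hypothetical stand-in for a decision that imposes no range limit."""

    def get_range_limit(self) -> Tuple[int, int]:
        raise NotImplementedError("no range limit for this decision")


class DecisionWithRange:
    """Hypothetical stand-in for a decision restricted to steps [start, end]."""

    def __init__(self, start_idx: int, end_idx: int) -> None:
        self.start_idx = start_idx
        self.end_idx = end_idx

    def get_range_limit(self) -> Tuple[int, int]:
        return self.start_idx, self.end_idx


def start_end_idx(decision, trade_len: int) -> Tuple[int, int]:
    """Return the closed index range in which the inner strategy may trade."""
    try:
        return decision.get_range_limit()
    except NotImplementedError:
        # No limit given: fall back to the whole calendar, steps 0 .. trade_len - 1.
        return 0, trade_len - 1


print(start_end_idx(DecisionWithoutRange(), trade_len=8))
print(start_end_idx(DecisionWithRange(2, 5), trade_len=8))
```

Both ends of the returned range are inclusive, which is why the fallback uses `trade_len - 1` rather than `trade_len`.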
|
||||
|
||||
|
||||
class TWAPStrategy(BaseStrategy):
|
||||
"""TWAP Strategy for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -23,11 +48,15 @@ class TWAPStrategy(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in TWAPStrategy
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of the outer strategy on which this strategy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allows different trade_exchanges to be used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both the daily exchange and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable; only the minutely exchange is recommended.
|
||||
|
||||
"""
|
||||
super(TWAPStrategy, self).__init__(
|
||||
@@ -51,33 +80,44 @@ class TWAPStrategy(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
|
||||
super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
for order in outer_trade_decision:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
for order in outer_trade_decision.get_decision():
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
# strategy is not available. Give an empty decision
|
||||
if len(self.outer_trade_decision.get_decision()) == 0:
|
||||
return TradeDecisionWO(order_list=[], strategy=self)
|
||||
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
start_idx, end_idx = get_start_end_idx(self, self.outer_trade_decision)
|
||||
trade_len = end_idx - start_idx + 1
|
||||
|
||||
if trade_step < start_idx:
|
||||
# It is not time to start trading
|
||||
return TradeDecisionWO(order_list=[], strategy=self)
|
||||
|
||||
rel_trade_step = trade_step - start_idx # trade_step relative to start_idx
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision:
|
||||
for order in self.outer_trade_decision.get_decision():
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
@@ -88,27 +128,31 @@ class TWAPStrategy(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - rel_trade_step)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1))
|
||||
# floor((trade_unit_cnt + trade_len - rel_trade_step - 1) / (trade_len - rel_trade_step)) == ceil(trade_unit_cnt / (trade_len - rel_trade_step))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit
|
||||
(trade_unit_cnt + trade_len - rel_trade_step - 1)
|
||||
// (trade_len - rel_trade_step)
|
||||
* _amount_trade_unit
|
||||
)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or trade_step == trade_len
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (
|
||||
_order_amount < 1e-5 or rel_trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=_order_amount,
|
||||
@@ -118,7 +162,7 @@ class TWAPStrategy(BaseStrategy):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return order_list
|
||||
return TradeDecisionWO(order_list=order_list, strategy=self)
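The trade-unit branch of the TWAP hunk above relies on the integer identity `(n + k - 1) // k == ceil(n / k)`, where `k` is the number of steps left. The following sketch replays that arithmetic outside the strategy. It assumes every slice is filled completely, and the helper names `twap_slice_units` and `twap_schedule` are hypothetical, introduced only for this illustration.

```python
from typing import List


def twap_slice_units(remaining_units: int, steps_left: int) -> int:
    """Units to trade now: ceil(remaining_units / steps_left) using integer math."""
    if steps_left <= 0:
        raise ValueError("steps_left must be positive")
    return (remaining_units + steps_left - 1) // steps_left


def twap_schedule(total_units: int, trade_len: int) -> List[int]:
    """Split total_units over trade_len steps, assuming each slice fully fills."""
    remaining = total_units
    schedule = []
    for rel_trade_step in range(trade_len):
        steps_left = trade_len - rel_trade_step
        # Never trade more than what is left, mirroring the min(...) in the diff.
        units = min(twap_slice_units(remaining, steps_left), remaining)
        schedule.append(units)
        remaining -= units
    return schedule


print(twap_schedule(10, 4))  # [3, 3, 2, 2]
```

Because the amount is rounded up at each step, earlier slices are never smaller than later ones, and the remainder reaches zero by the last step when every slice fills.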
|
||||
|
||||
|
||||
class SBBStrategyBase(BaseStrategy):
|
||||
@@ -130,9 +174,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
TREND_SHORT = 1
|
||||
TREND_LONG = 2
|
||||
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
@@ -140,11 +189,15 @@ class SBBStrategyBase(BaseStrategy):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order]
|
||||
the trade decison of outer strategy which this startegy relies, it should be List[Order] in SBBStrategyBase
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the trade decision of the outer strategy on which this strategy relies
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allows different trade_exchanges to be used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both the daily exchange and the minutely exchange are usable, but the daily exchange is recommended because it runs faster.
|
||||
- In minutely execution, the daily exchange is not usable; only the minutely exchange is recommended.
|
||||
"""
|
||||
super(SBBStrategyBase, self).__init__(
|
||||
outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra
|
||||
@@ -166,52 +219,53 @@ class SBBStrategyBase(BaseStrategy):
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset(self, outer_trade_decision: List[Order] = None, **kwargs):
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : List[Order], optional
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_trend = {}
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of order and predicted trade trend
|
||||
for order in outer_trade_decision:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = self.TREND_MID
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
for order in outer_trade_decision.get_decision():
|
||||
self.trade_trend[order.stock_id] = self.TREND_MID
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
order_list = []
|
||||
# for each order in self.outer_trade_decision
|
||||
for order in self.outer_trade_decision:
|
||||
for order in self.outer_trade_decision.get_decision():
|
||||
# get the price trend
|
||||
if trade_step % 2 == 0:
|
||||
# in the first of two adjacent bars, predict the price trend
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
# in the second of two adjacent bars, use the trend predicted in the first one
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
_pred_trend = self.trade_trend[order.stock_id]
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
if trade_step % 2 == 0:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
continue
|
||||
# get amount of one trade unit
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
@@ -220,12 +274,12 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step)
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
@@ -233,12 +287,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and (
|
||||
_order_amount is None or trade_step == trade_len - 1
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=_order_amount,
|
||||
@@ -254,13 +310,11 @@ class SBBStrategyBase(BaseStrategy):
|
||||
# considering trade unit
|
||||
if _amount_trade_unit is None:
|
||||
# N trade days left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1)
|
||||
)
|
||||
_order_amount = 2 * self.trade_amount[order.stock_id] / (trade_len - trade_step + 1)
|
||||
# without considering trade unit
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
else:
|
||||
# calculate how many trade units remain
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# N trade days left, divide the order into N + 1 parts, and trade 2 parts
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step)
|
||||
@@ -270,13 +324,14 @@ class SBBStrategyBase(BaseStrategy):
|
||||
)
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and (
|
||||
_order_amount is None or trade_step == trade_len - 1
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (
|
||||
_order_amount < 1e-5 or trade_step == trade_len - 1
|
||||
):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)]
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
if _order_amount:
|
||||
_order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)])
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
if trade_step % 2 == 0:
|
||||
# in the first one of two adjacent bars
|
||||
# if look short on the price, sell the stock more
|
||||
@@ -318,9 +373,9 @@ class SBBStrategyBase(BaseStrategy):
|
||||
|
||||
if trade_step % 2 == 0:
|
||||
# in the first one of two adjacent bars, store the trend for the second one to use
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
self.trade_trend[order.stock_id] = _pred_trend
|
||||
|
||||
return order_list
|
||||
return TradeDecisionWO(order_list, self)
|
||||
|
||||
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
@@ -328,9 +383,14 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA) signal.
|
||||
"""
|
||||
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: List[Order] = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
@@ -399,6 +459,240 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
# if EMA signal > 0, return long trend
|
||||
elif _sample_signal.iloc[0] > 0:
|
||||
return self.TREND_LONG
|
||||
# if EMA signal > 0, return short trend
|
||||
# if EMA signal < 0, return short trend
|
||||
else:
|
||||
return self.TREND_SHORT
|
||||
|
||||
|
||||
class ACStrategy(BaseStrategy):
|
||||
# TODO:
|
||||
# 1. Supporting leverage the get_range_limit result from the decision
|
||||
# 2. Supporting alter_outer_trade_decision
|
||||
# 3. Supporting checking the availability of trade decision
|
||||
def __init__(
|
||||
self,
|
||||
lamb: float = 1e-6,
|
||||
eta: float = 2.5e-6,
|
||||
window_size: int = 20,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
instruments: Union[List, str] = "csi300",
|
||||
freq: str = "day",
|
||||
trade_exchange: Exchange = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
instruments : Union[List, str], optional
|
||||
instruments of Volatility, by default "csi300"
|
||||
freq : str, optional
|
||||
freq of Volatility, by default "day"
|
||||
Note: `freq` may be different from `time_per_step`
|
||||
"""
|
||||
self.lamb = lamb
|
||||
self.eta = eta
|
||||
self.window_size = window_size
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
super(ACStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
|
||||
|
||||
if trade_exchange is not None:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def _reset_signal(self):
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
fields = [
|
||||
f"Power(Sum(Power(Log($close/Ref($close, 1)), 2), {self.window_size})/{self.window_size - 1}-Power(Sum(Log($close/Ref($close, 1)), {self.window_size}), 2)/({self.window_size}*{self.window_size - 1}), 0.5)"
|
||||
]
|
||||
signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1)
|
||||
_, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq
|
||||
)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["volatility"]
|
||||
self.signal = {}
|
||||
|
||||
if not signal_df.empty:
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
self.signal[stock_id] = stock_val
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
common_infra : CommonInfrastructure, optional
|
||||
common infrastructure for backtesting, by default None
|
||||
- It should include `trade_account`, used to get position
|
||||
- It should include `trade_exchange`, used to provide market info
|
||||
"""
|
||||
super(ACStrategy, self).reset_common_infra(common_infra)
|
||||
|
||||
if common_infra.has("trade_exchange"):
|
||||
self.trade_exchange = common_infra.get("trade_exchange")
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
"""
|
||||
reset level-shared infra
|
||||
- After the trade calendar is reset, the signal is recomputed
|
||||
"""
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self._reset_signal()
|
||||
|
||||
def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
"""
|
||||
super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs)
|
||||
if outer_trade_decision is not None:
|
||||
self.trade_amount = {}
|
||||
# init the trade amount of each order
|
||||
for order in outer_trade_decision.get_decision():
|
||||
self.trade_amount[order.stock_id] = order.amount
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
# get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1]
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
# get the total count of trading step
|
||||
trade_len = self.trade_calendar.get_trade_len()
|
||||
|
||||
# update the order amount
|
||||
if execute_result is not None:
|
||||
for order, _, _, _ in execute_result:
|
||||
self.trade_amount[order.stock_id] -= order.deal_amount
|
||||
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
|
||||
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
|
||||
order_list = []
|
||||
for order in self.outer_trade_decision.get_decision():
|
||||
# if not tradable, continue
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time
|
||||
):
|
||||
continue
|
||||
_order_amount = None
|
||||
# considering trade unit
|
||||
|
||||
sig_sam = (
|
||||
resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last")
|
||||
if order.stock_id in self.signal
|
||||
else None
|
||||
)
|
||||
|
||||
if sig_sam is None or sig_sam.iloc[0] is None:
|
||||
# no signal, TWAP
|
||||
_amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor)
|
||||
if _amount_trade_unit is None:
|
||||
# divide the order into equal parts, and trade one part
|
||||
_order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step)
|
||||
else:
|
||||
# divide the order into equal parts, and trade one part
|
||||
# calculate the total count of trade units to trade
|
||||
trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit)
|
||||
# calculate the amount of one part, ceil the amount
|
||||
# floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step))
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit
|
||||
)
|
||||
else:
|
||||
# VA strategy
|
||||
kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0]
|
||||
kappa = np.arccosh(kappa_tild / 2 + 1)
|
||||
amount_ratio = (
|
||||
np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1))
|
||||
) / np.sinh(kappa * trade_len)
|
||||
_order_amount = order.amount * amount_ratio
|
||||
_order_amount = self.trade_exchange.round_amount_by_trade_unit(_order_amount, order.factor)
|
||||
|
||||
if order.direction == order.SELL:
|
||||
# sell all amount at last
|
||||
if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1):
|
||||
_order_amount = self.trade_amount[order.stock_id]
|
||||
|
||||
_order_amount = min(_order_amount, self.trade_amount[order.stock_id])
|
||||
|
||||
if _order_amount > 1e-5:
|
||||
|
||||
_order = Order(
|
||||
stock_id=order.stock_id,
|
||||
amount=_order_amount,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=order.direction, # 1 for buy
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
return TradeDecisionWO(order_list, self)
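The volatility branch of `ACStrategy.generate_trade_decision` computes a per-step `amount_ratio` from hyperbolic sines. Summed over all steps, the numerators telescope to `sinh(kappa * trade_len)`, so the ratios add up to 1. The sketch below reproduces only that arithmetic with the standard `math` module. The function name `ac_amount_ratios` and the equal-split fallback for zero volatility are assumptions made for this example, and the trade-unit rounding done in the diff is left out.

```python
import math
from typing import List


def ac_amount_ratios(lamb: float, eta: float, volatility: float, trade_len: int) -> List[float]:
    """Fraction of the original order amount to trade at each step."""
    kappa_tild = lamb / eta * volatility * volatility
    kappa = math.acosh(kappa_tild / 2 + 1)
    if kappa == 0.0:
        # Assumption for this sketch: zero volatility would give 0 / 0,
        # so fall back to an equal split.
        return [1.0 / trade_len] * trade_len
    denom = math.sinh(kappa * trade_len)
    return [
        (math.sinh(kappa * (trade_len - step)) - math.sinh(kappa * (trade_len - step - 1))) / denom
        for step in range(trade_len)
    ]


ratios = ac_amount_ratios(lamb=1e-6, eta=2.5e-6, volatility=0.5, trade_len=5)
print([round(r, 4) for r in ratios])
print(round(sum(ratios), 6))
```

Since `sinh` is convex for positive arguments, the ratios decrease from the first step to the last, so higher volatility front-loads more of the order.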
|
||||
|
||||
|
||||
class RandomOrderStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
index_range: Tuple[int, int], # The range is closed on both left and right.
|
||||
sample_ratio: float = 1.0,
|
||||
volume_ratio: float = 0.01,
|
||||
market: str = "all",
|
||||
direction: int = Order.BUY,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
index_range : Tuple
|
||||
the intraday time index range of the orders
|
||||
both ends of the range are closed (inclusive).
|
||||
# TODO: this is an index_range-level limitation. We'll implement a more detailed limitation later.
|
||||
sample_ratio : float
|
||||
the ratio of all orders that are sampled
|
||||
volume_ratio : float
|
||||
the volume of the total day
|
||||
ratio of the total volume of a specific day
|
||||
market : str
|
||||
stock pool for sampling
|
||||
"""
|
||||
|
||||
super().__init__(*args, **kwargs)
|
||||
self.index_range = index_range
|
||||
self.sample_ratio = sample_ratio
|
||||
self.volume_ratio = volume_ratio
|
||||
self.market = market
|
||||
self.direction = direction
|
||||
exch: Exchange = self.common_infra.get("trade_exchange")
|
||||
# TODO: this can't be online
|
||||
self.volume = D.features(
|
||||
D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time
|
||||
)
|
||||
self.volume_df = self.volume.iloc[:, 0].unstack()
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
trade_step = self.trade_calendar.get_trade_step()
|
||||
step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step)
|
||||
|
||||
order_list = []
|
||||
if step_time_start in self.volume_df:
|
||||
for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items():
|
||||
order_list.append(
|
||||
self.common_infra.get("trade_exchange").create_order(
|
||||
code=stock_id,
|
||||
amount=volume * self.volume_ratio,
|
||||
start_time=step_time_start,
|
||||
end_time=step_time_end,
|
||||
direction=self.direction,
|
||||
)
|
||||
)
|
||||
return TradeDecisionWO(order_list, self, self.index_range)
|
||||
|
||||
@@ -15,6 +15,7 @@ import bisect
|
||||
import logging
|
||||
import importlib
|
||||
import traceback
|
||||
from typing import List, Union
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from multiprocessing import Pool
|
||||
@@ -65,7 +66,6 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
|
||||
@@ -87,7 +87,22 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
list
|
||||
calendar list
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarProvider must implement `calendar` method")
|
||||
_calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
|
||||
# strip
|
||||
if start_time:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if start_time > _calendar[-1]:
|
||||
return np.array([])
|
||||
else:
|
||||
start_time = _calendar[0]
|
||||
if end_time:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
if end_time < _calendar[0]:
|
||||
return np.array([])
|
||||
else:
|
||||
end_time = _calendar[-1]
|
||||
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
|
||||
return _calendar[si : ei + 1]
|
||||
|
||||
def locate_index(self, start_time, end_time, freq, freq_sam=None, future=False):
|
||||
"""Locate the start time index and end time index in a calendar under certain frequency.
|
||||
@@ -172,6 +187,21 @@ class CalendarProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Get the uri of calendar generation task."""
|
||||
return hash_args(start_time, end_time, freq, future)
|
||||
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
|
||||
Returns
|
||||
----------
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
raise NotImplementedError("Subclass of CalendarProvider must implement `load_calendar` method")
|
||||
|
||||
|
||||
class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
"""Instrument provider base class
|
||||
@@ -183,19 +213,22 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
self.backend = kwargs.get("backend", {})
|
||||
|
||||
@staticmethod
|
||||
def instruments(market="all", filter_pipe=None):
|
||||
def instruments(market: Union[List, str] = "all", filter_pipe: Union[List, None] = None):
|
||||
"""Get the general config dictionary for a base market adding several dynamic filters.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
market : str
|
||||
market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
|
||||
market : Union[List, str]
|
||||
str:
|
||||
market/industry/index shortname, e.g. all/sse/szse/sse50/csi300/csi500.
|
||||
list:
|
||||
["ID1", "ID2"]. A list of stocks
|
||||
filter_pipe : list
|
||||
the list of dynamic filters.
|
||||
|
||||
Returns
|
||||
----------
|
||||
dict
|
||||
dict: if isinstance(market, str)
|
||||
dict of stockpool config.
|
||||
{`market`=>base market name, `filter_pipe`=>list of filters}
|
||||
|
||||
@@ -213,7 +246,13 @@ class InstrumentProvider(abc.ABC, ProviderBackendMixin):
|
||||
'name_rule_re': 'SH[0-9]{4}55',
|
||||
'filter_start_time': None,
|
||||
'filter_end_time': None}]}
|
||||
|
||||
list: if isinstance(market, list)
|
||||
just return the original list directly.
|
||||
NOTE: this will make the instruments compatible with more cases. The user code will be simpler.
|
||||
"""
|
||||
if isinstance(market, list):
|
||||
return market
|
||||
if filter_pipe is None:
|
||||
filter_pipe = []
|
||||
config = {"market": market, "filter_pipe": []}
|
||||
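The new list pass-through in `instruments` can be illustrated in isolation. The sketch below is an assumption-laden simplification: `resolve_instruments` is a hypothetical name, and the filter handling is reduced to copying the list, since the rest of the real method is not shown in this hunk.

```python
from typing import List, Optional, Union


def resolve_instruments(
    market: Union[List[str], str] = "all",
    filter_pipe: Optional[list] = None,
):
    """Hypothetical stand-in for the dispatch added to `instruments`."""
    if isinstance(market, list):
        # An explicit list of instrument IDs is returned unchanged.
        return market
    if filter_pipe is None:
        filter_pipe = []
    # Simplification: the real method builds the filter configs itself.
    return {"market": market, "filter_pipe": list(filter_pipe)}
```

With this shape, calling code can hand either a market name or an explicit stock list to the same entry point without branching on the type itself.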
@@ -457,7 +496,8 @@ class DatasetProvider(abc.ABC):
|
||||
normalize_column_names = normalize_cache_fields(column_names)
|
||||
data = dict()
|
||||
# One process for one task, so that the memory will be freed quicker.
|
||||
workers = min(C.kernels, len(instruments_d))
|
||||
workers = max(min(C.kernels, len(instruments_d)), 1)
|
||||
|
||||
if C.maxtasksperchild is None:
|
||||
p = Pool(processes=workers)
|
||||
else:
|
||||
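The `max(..., 1)` guard matters when the instrument set is empty: `multiprocessing.Pool` raises `ValueError` if asked for fewer than one process. A minimal sketch of the arithmetic, with `resolve_workers` as a hypothetical helper name:

```python
def resolve_workers(kernels: int, n_instruments: int) -> int:
    """Cap the worker count by the task count, but never go below 1.

    Hypothetical helper mirroring `max(min(C.kernels, len(instruments_d)), 1)`.
    """
    return max(min(kernels, n_instruments), 1)
```

With no instruments the old expression evaluated to 0, which the pool constructor rejects; the guarded form yields 1, so the function can proceed and return an empty result.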
@@ -504,7 +544,9 @@ class DatasetProvider(abc.ABC):
|
||||
data = pd.concat(new_data, names=["instrument"], sort=False)
|
||||
data = DiskDatasetCache.cache_to_origin_data(data, column_names)
|
||||
else:
|
||||
data = pd.DataFrame(columns=column_names)
|
||||
data = pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -558,19 +600,6 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
return os.path.join(C.get_data_path(), "calendars", "{}.txt")
|
||||
|
||||
def load_calendar(self, freq, future):
|
||||
"""Load original calendar timestamp from file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
frequency of read calendar file.
|
||||
|
||||
Returns
|
||||
----------
|
||||
list
|
||||
list of timestamps
|
||||
"""
|
||||
|
||||
try:
|
||||
backend_obj = self.backend_obj(freq=freq, future=future).data
|
||||
except ValueError:
|
||||
@@ -587,24 +616,6 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
|
||||
return [pd.Timestamp(x) for x in backend_obj]
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
_calendar, _ = self._get_calendar(freq=freq, freq_sam=freq_sam, future=future)
|
||||
# strip
|
||||
if start_time:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if start_time > _calendar[-1]:
|
||||
return np.array([])
|
||||
else:
|
||||
start_time = _calendar[0]
|
||||
if end_time:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
if end_time < _calendar[0]:
|
||||
return np.array([])
|
||||
else:
|
||||
end_time = _calendar[-1]
|
||||
st, et, si, ei = self.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam, future=future)
|
||||
return _calendar[si : ei + 1]
|
||||
|
||||
|
||||
class LocalInstrumentProvider(InstrumentProvider):
|
||||
"""Local instrument data provider class
|
||||
@@ -719,7 +730,9 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
if len(cal) == 0:
|
||||
return pd.DataFrame(columns=column_names)
|
||||
return pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")), columns=column_names
|
||||
)
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
@@ -741,7 +754,7 @@ class LocalDatasetProvider(DatasetProvider):
|
||||
return
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
workers = min(C.kernels, len(instruments_d))
|
||||
workers = max(min(C.kernels, len(instruments_d)), 1)
|
||||
if C.maxtasksperchild is None:
|
||||
p = Pool(processes=workers)
|
||||
else:
|
||||
@@ -789,7 +802,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="trade_calendar",
|
||||
request_type="calendar",
|
||||
request_content={
|
||||
"start_time": str(start_time),
|
||||
"end_time": str(end_time),
|
||||
@@ -902,7 +915,10 @@ class ClientDatasetProvider(DatasetProvider):
|
||||
column_names = self.get_column_names(fields)
|
||||
cal = Cal.calendar(start_time, end_time, freq)
|
||||
if len(cal) == 0:
|
||||
return pd.DataFrame(columns=column_names)
|
||||
return pd.DataFrame(
|
||||
index=pd.MultiIndex.from_arrays([[], []], names=("instrument", "datetime")),
|
||||
columns=column_names,
|
||||
)
|
||||
start_time = cal[0]
|
||||
end_time = cal[-1]
|
||||
|
||||
@@ -1004,7 +1020,7 @@ class LocalProvider(BaseProvider):
|
||||
:param type: The type of resource for the uri
|
||||
:param **kwargs:
|
||||
"""
|
||||
if type == "trade_calendar":
|
||||
if type == "calendar":
|
||||
return Cal._uri(**kwargs)
|
||||
elif type == "instrument":
|
||||
return Inst._uri(**kwargs)
|
||||
|
||||
@@ -68,7 +68,7 @@ def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logge
|
||||
|
||||
class TimeInspector:
|
||||
|
||||
timer_logger = get_module_logger("timer", level=logging.WARNING)
|
||||
timer_logger = get_module_logger("timer", level=logging.INFO)
|
||||
|
||||
time_marks = []
|
||||
|
||||
|
||||
@@ -12,9 +12,11 @@ In ``DelayTrainer``, the first step is only to save some necessary info to model
|
||||
"""
|
||||
|
||||
import socket
|
||||
import time
|
||||
from typing import Callable, List
|
||||
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config
|
||||
from qlib.workflow import R
|
||||
@@ -190,6 +192,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
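The `isinstance(tasks, dict)` check added here lets callers pass a single task config instead of a list. A small sketch of why the wrapping is needed (the helper name `as_task_list` is hypothetical):

```python
def as_task_list(tasks):
    """Wrap a single task dict in a list; leave lists untouched."""
    if isinstance(tasks, dict):
        tasks = [tasks]
    return tasks


single_task = {"model": "m", "dataset": "d"}
# Without wrapping, len() would count the dict's keys and a for-loop would
# iterate over key names rather than over tasks.
wrapped = as_task_list(single_task)
```

The same pattern is applied to `recs` with `isinstance(recs, Recorder)` in the `end_train` methods.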
@@ -213,6 +217,8 @@ class TrainerR(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
@@ -250,6 +256,8 @@ class DelayTrainerR(TrainerR):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -275,6 +283,9 @@ class TrainerRM(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
# This tag is the _id in TaskManager to distinguish tasks.
|
||||
TM_ID = "_id in TaskManager"
|
||||
|
||||
def __init__(self, experiment_name: str = None, task_pool: str = None, train_func=task_train):
|
||||
"""
|
||||
Init TrainerR.
|
||||
@@ -315,6 +326,8 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
if train_func is None:
|
||||
@@ -326,19 +339,25 @@ class TrainerRM(Trainer):
|
||||
task_pool = experiment_name
|
||||
tm = TaskManager(task_pool=task_pool)
|
||||
_id_list = tm.create_task(tasks) # all tasks will be saved to MongoDB
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
train_func,
|
||||
task_pool,
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=before_status,
|
||||
after_status=after_status,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if not self.is_delay():
|
||||
tm.wait(query=query)
|
||||
|
||||
recs = []
|
||||
for _id in _id_list:
|
||||
rec = tm.re_query(_id)["res"]
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
rec.set_tags(**{self.TM_ID: _id})
|
||||
recs.append(rec)
|
||||
return recs
|
||||
|
||||
@@ -352,10 +371,33 @@ class TrainerRM(Trainer):
|
||||
Returns:
|
||||
List[Recorder]: the same list as the param.
|
||||
"""
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(
|
||||
self,
|
||||
train_func: Callable = None,
|
||||
experiment_name: str = None,
|
||||
):
|
||||
"""
|
||||
The multiprocessing method for `train`. It can share the same task_pool with `train` and can run in other processes or on other machines.
|
||||
|
||||
Args:
|
||||
train_func (Callable): the training method which needs at least `task`s and `experiment_name`. None for the default training method.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if train_func is None:
|
||||
train_func = self.train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(train_func, task_pool=task_pool, experiment_name=experiment_name)
|
||||
|
||||
|
||||
class DelayTrainerRM(TrainerRM):
|
||||
"""
|
||||
@@ -395,6 +437,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
if isinstance(tasks, dict):
|
||||
tasks = [tasks]
|
||||
if len(tasks) == 0:
|
||||
return []
|
||||
return super().train(
|
||||
@@ -410,8 +454,6 @@ class DelayTrainerRM(TrainerRM):
|
||||
Given a list of Recorder and return a list of trained Recorder.
|
||||
This class will finish real data loading and model fitting.
|
||||
|
||||
NOTE: This method will train all STATUS_PART_DONE tasks in the task pool, not only the ``recs``.
|
||||
|
||||
Args:
|
||||
recs (list): a list of Recorder, the tasks have been saved to them.
|
||||
end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
@@ -421,7 +463,8 @@ class DelayTrainerRM(TrainerRM):
|
||||
Returns:
|
||||
List[Recorder]: a list of Recorders
|
||||
"""
|
||||
|
||||
if isinstance(recs, Recorder):
|
||||
recs = [recs]
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
@@ -429,18 +472,44 @@ class DelayTrainerRM(TrainerRM):
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
tasks = []
|
||||
_id_list = []
|
||||
for rec in recs:
|
||||
tasks.append(rec.load_object("task"))
|
||||
_id_list.append(rec.list_tags()[self.TM_ID])
|
||||
|
||||
query = {"_id": {"$in": _id_list}}
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool,
|
||||
query={"filter": {"$in": tasks}}, # only train these tasks
|
||||
query=query, # only train these tasks
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
TaskManager(task_pool=task_pool).wait(query=query)
|
||||
|
||||
for rec in recs:
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_END})
|
||||
return recs
|
||||
|
||||
def worker(self, end_train_func=None, experiment_name: str = None):
|
||||
"""
|
||||
The multiprocessing method for `end_train`. It can share the same task_pool with `end_train` and can run in other processes or on other machines.
|
||||
|
||||
Args:
|
||||
end_train_func (Callable, optional): the end_train method which needs at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func.
|
||||
experiment_name (str): the experiment name, None for use default name.
|
||||
"""
|
||||
if end_train_func is None:
|
||||
end_train_func = self.end_train_func
|
||||
if experiment_name is None:
|
||||
experiment_name = self.experiment_name
|
||||
task_pool = self.task_pool
|
||||
if task_pool is None:
|
||||
task_pool = experiment_name
|
||||
run_task(
|
||||
end_train_func,
|
||||
task_pool=task_pool,
|
||||
experiment_name=experiment_name,
|
||||
before_status=TaskManager.STATUS_PART_DONE,
|
||||
)
|
||||
|
||||
@@ -5,14 +5,14 @@
|
||||
class BaseInterpreter:
|
||||
"""Base Interpreter"""
|
||||
|
||||
def interpret(**kwargs):
|
||||
def interpret(self, **kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
"""Action Interpreter that interprets rl agent action into qlib orders"""
|
||||
|
||||
def interpret(action, **kwargs):
|
||||
def interpret(self, action, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
Parameters
|
||||
@@ -32,7 +32,7 @@ class ActionInterpreter(BaseInterpreter):
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
"""State Interpreter that interprets execution result of qlib executor into rl env state"""
|
||||
|
||||
def interpret(execute_result, **kwargs):
|
||||
def interpret(self, execute_result, **kwargs):
|
||||
"""interpret method
|
||||
|
||||
Parameters
|
||||
|
||||
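The `self` additions in these interpreter hunks fix a real calling problem: an instance method defined without `self` receives the instance as its first positional argument. The sketch below uses hypothetical class names to show the difference.

```python
class BrokenInterpreter:
    def interpret(action, **kwargs):  # `self` is missing
        return action


class FixedInterpreter:
    def interpret(self, action, **kwargs):
        return action


def call_interpret(interpreter, action):
    """Call `interpret` the way `RLIntStrategy` does, by keyword."""
    try:
        return interpreter.interpret(action=action)
    except TypeError as err:
        # The instance was bound to `action`, so the keyword collides.
        return f"TypeError: {err}"
```

Calling the broken variant by keyword fails because the instance already occupies the `action` slot; the fixed variant returns the action unchanged.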
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from typing import Union
|
||||
from typing import List, Union
|
||||
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.order import BaseTradeDecision
|
||||
|
||||
__all__ = ['BaseStrategy', 'ModelStrategy', 'RLStrategy', 'RLIntStrategy']
|
||||
|
||||
@@ -17,16 +18,16 @@ class BaseStrategy:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : object, optional
|
||||
the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decison, it will be used
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
the trade decision of the outer strategy on which this strategy relies; it will be traded in [start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decision, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : LevelInfrastructure, optional
|
||||
level shared infrastructure for backtesting, including trade calendar
|
||||
@@ -36,18 +37,18 @@ class BaseStrategy:
|
||||
|
||||
self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision)
|
||||
|
||||
def reset_level_infra(self, level_infra):
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
if level_infra.has("trade_calendar"):
|
||||
self.trade_calendar = level_infra.get("trade_calendar")
|
||||
self.trade_calendar: TradeCalendarManager = level_infra.get("trade_calendar")
|
||||
|
||||
def reset_common_infra(self, common_infra):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure):
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra = common_infra
|
||||
self.common_infra: CommonInfrastructure = common_infra
|
||||
else:
|
||||
self.common_infra.update(common_infra)
|
||||
|
||||
@@ -64,7 +65,7 @@ class BaseStrategy:
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, etc.
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, etc.
|
||||
- reset `outer_trade_decision`, used to make split decison
|
||||
- reset `outer_trade_decision`, used to make split decision
|
||||
"""
|
||||
if level_infra is not None:
|
||||
self.reset_level_infra(level_infra)
|
||||
@@ -81,11 +82,45 @@ class BaseStrategy:
|
||||
Parameters
|
||||
----------
|
||||
execute_result : List[object], optional
|
||||
the executed result for trade decison, by default None
|
||||
the executed result for trade decision, by default None
|
||||
- When `generate_trade_decision` is called for the first time, `execute_result` could be None
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
def update_trade_decision(
|
||||
self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
|
||||
) -> Union[BaseTradeDecision, None]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_decision : BaseTradeDecision
|
||||
the trade decision that will be updated
|
||||
trade_calendar : TradeCalendarManager
|
||||
The calendar of the **inner strategy**!!!!!
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision:
|
||||
"""
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
|
||||
"""
|
||||
A method for updating the outer_trade_decision.
|
||||
The outer strategy may change its decision during updating.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
"""
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
raise NotImplementedError(f"Please implement the `alter_outer_trade_decision` method")
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
"""Model-based trading strategy, use model to make predictions for trading"""
|
||||
@@ -94,7 +129,7 @@ class ModelStrategy(BaseStrategy):
|
||||
self,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -130,7 +165,7 @@ class RLStrategy(BaseStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
policy,
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -153,7 +188,7 @@ class RLIntStrategy(RLStrategy):
|
||||
policy,
|
||||
state_interpreter: Union[dict, StateInterpreter],
|
||||
action_interpreter: Union[dict, ActionInterpreter],
|
||||
outer_trade_decision: object = None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
@@ -177,7 +212,7 @@ class RLIntStrategy(RLStrategy):
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
_interpret_state = self.state_interpretor.interpret(execute_result=execute_result)
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
|
||||
_action = self.policy.step(_interpret_state)
|
||||
_trade_decision = self.action_interpreter.interpret(action=_action)
|
||||
return _trade_decision
|
||||
|
||||
@@ -43,17 +43,29 @@ RECORD_CONFIG = [
|
||||
]
|
||||
|
||||
|
||||
def get_data_handler_config(market=CSI300_MARKET):
|
||||
def get_data_handler_config(
|
||||
start_time="2008-01-01",
|
||||
end_time="2020-08-01",
|
||||
fit_start_time="2008-01-01",
|
||||
fit_end_time="2014-12-31",
|
||||
instruments=CSI300_MARKET,
|
||||
):
|
||||
return {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
"fit_start_time": fit_start_time,
|
||||
"fit_end_time": fit_end_time,
|
||||
"instruments": instruments,
|
||||
}
|
||||
|
||||
|
||||
def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLASS):
|
||||
def get_dataset_config(
|
||||
dataset_class=DATASET_ALPHA158_CLASS,
|
||||
train=("2008-01-01", "2014-12-31"),
|
||||
valid=("2015-01-01", "2016-12-31"),
|
||||
test=("2017-01-01", "2020-08-01"),
|
||||
handler_kwargs={"instruments": CSI300_MARKET},
|
||||
):
|
||||
return {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
@@ -61,48 +73,88 @@ def get_dataset_config(market=CSI300_MARKET, dataset_class=DATASET_ALPHA158_CLAS
|
||||
"handler": {
|
||||
"class": dataset_class,
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": get_data_handler_config(market),
|
||||
"kwargs": get_data_handler_config(**handler_kwargs),
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
"train": train,
|
||||
"valid": valid,
|
||||
"test": test,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_gbdt_task(market=CSI300_MARKET):
|
||||
def get_gbdt_task(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": GBDT_MODEL,
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
}
|
||||
|
||||
|
||||
def get_record_lgb_config(market=CSI300_MARKET):
|
||||
def get_record_lgb_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
def get_record_xgboost_config(market=CSI300_MARKET):
|
||||
def get_record_xgboost_config(dataset_kwargs={}, handler_kwargs={"instruments": CSI300_MARKET}):
|
||||
return {
|
||||
"model": {
|
||||
"class": "XGBModel",
|
||||
"module_path": "qlib.contrib.model.xgboost",
|
||||
},
|
||||
"dataset": get_dataset_config(market),
|
||||
"dataset": get_dataset_config(**dataset_kwargs, handler_kwargs=handler_kwargs),
|
||||
"record": RECORD_CONFIG,
|
||||
}
|
||||
|
||||
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(market=CSI300_MARKET)
|
||||
CSI300_GBDT_TASK = get_gbdt_task(market=CSI300_MARKET)
|
||||
CSI300_DATASET_CONFIG = get_dataset_config(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
CSI300_GBDT_TASK = get_gbdt_task(handler_kwargs={"instruments": CSI300_MARKET})
|
||||
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(market=CSI100_MARKET)
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG = get_record_xgboost_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
CSI100_RECORD_LGB_TASK_CONFIG = get_record_lgb_config(handler_kwargs={"instruments": CSI100_MARKET})
|
||||
|
||||
# use for rolling_online_managment.py
|
||||
ROLLING_HANDLER_CONFIG = {
|
||||
"start_time": "2013-01-01",
|
||||
"end_time": "2020-09-25",
|
||||
"fit_start_time": "2013-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ROLLING_DATASET_CONFIG = {
|
||||
"train": ("2013-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2015-12-31"),
|
||||
"test": ("2016-01-01", "2020-07-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ROLLING = get_record_xgboost_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ROLLING = get_record_lgb_config(
|
||||
dataset_kwargs=ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
# use for online_management_simulate.py
|
||||
ONLINE_HANDLER_CONFIG = {
|
||||
"start_time": "2018-01-01",
|
||||
"end_time": "2018-10-31",
|
||||
"fit_start_time": "2018-01-01",
|
||||
"fit_end_time": "2018-03-31",
|
||||
"instruments": CSI100_MARKET,
|
||||
}
|
||||
ONLINE_DATASET_CONFIG = {
|
||||
"train": ("2018-01-01", "2018-03-31"),
|
||||
"valid": ("2018-04-01", "2018-05-31"),
|
||||
"test": ("2018-06-01", "2018-09-10"),
|
||||
}
|
||||
CSI100_RECORD_XGBOOST_TASK_CONFIG_ONLINE = get_record_xgboost_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
CSI100_RECORD_LGB_TASK_CONFIG_ONLINE = get_record_lgb_config(
|
||||
dataset_kwargs=ONLINE_DATASET_CONFIG, handler_kwargs=ONLINE_HANDLER_CONFIG
|
||||
)
|
||||
|
||||
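The refactored config helpers compose by forwarding `dataset_kwargs` and `handler_kwargs` down the call chain. The sketch below reproduces that composition in reduced form. Several things are assumptions: the placeholder values `"Alpha158"`, `"csi300"` and `"csi100"` stand in for constants not shown in the diff, the enclosing `"kwargs"` key is inferred from the hunk gap, and `None` defaults replace the `{}` defaults purely as a precaution against shared mutable defaults.

```python
def get_data_handler_config(
    start_time="2008-01-01",
    end_time="2020-08-01",
    fit_start_time="2008-01-01",
    fit_end_time="2014-12-31",
    instruments="csi300",  # assumed value of CSI300_MARKET
):
    return {
        "start_time": start_time,
        "end_time": end_time,
        "fit_start_time": fit_start_time,
        "fit_end_time": fit_end_time,
        "instruments": instruments,
    }


def get_dataset_config(
    dataset_class="Alpha158",  # assumed value of DATASET_ALPHA158_CLASS
    train=("2008-01-01", "2014-12-31"),
    valid=("2015-01-01", "2016-12-31"),
    test=("2017-01-01", "2020-08-01"),
    handler_kwargs=None,
):
    if handler_kwargs is None:
        handler_kwargs = {"instruments": "csi300"}
    return {
        "class": "DatasetH",
        "module_path": "qlib.data.dataset",
        "kwargs": {  # assumed nesting; this line falls in the hunk gap
            "handler": {
                "class": dataset_class,
                "module_path": "qlib.contrib.data.handler",
                "kwargs": get_data_handler_config(**handler_kwargs),
            },
            "segments": {"train": train, "valid": valid, "test": test},
        },
    }


ROLLING_HANDLER_CONFIG = {
    "start_time": "2013-01-01",
    "end_time": "2020-09-25",
    "fit_start_time": "2013-01-01",
    "fit_end_time": "2014-12-31",
    "instruments": "csi100",  # assumed value of CSI100_MARKET
}
ROLLING_DATASET_CONFIG = {
    "train": ("2013-01-01", "2014-12-31"),
    "valid": ("2015-01-01", "2015-12-31"),
    "test": ("2016-01-01", "2020-07-10"),
}

rolling = get_dataset_config(**ROLLING_DATASET_CONFIG, handler_kwargs=ROLLING_HANDLER_CONFIG)
```

Because the segment keys in `ROLLING_DATASET_CONFIG` match the parameter names of `get_dataset_config`, the dictionary can be unpacked directly with `**`.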
@@ -7,52 +7,7 @@ from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
from . import lazy_sort_index
|
||||
from ..config import C
|
||||
|
||||
|
||||
def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
"""
|
||||
Parse freq into a unified format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq: Tuple[int, str]
|
||||
Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
print(parse_freq("day"))
|
||||
(1, "day" )
|
||||
print(parse_freq("2mon"))
|
||||
(2, "month")
|
||||
print(parse_freq("10w"))
|
||||
(10, "week")
|
||||
|
||||
"""
|
||||
freq = freq.lower()
|
||||
match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
|
||||
if match_obj is None:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
|
||||
)
|
||||
_count = int(match_obj.group(1)) if match_obj.group(1) else 1
|
||||
_freq = match_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": "month",
|
||||
"mon": "month",
|
||||
"week": "week",
|
||||
"w": "week",
|
||||
"day": "day",
|
||||
"d": "day",
|
||||
"minute": "minute",
|
||||
"min": "minute",
|
||||
}
|
||||
return _count, _freq_format_dict[_freq]
|
||||
from .time import Freq, cal_sam_minute
|
||||
|
||||
|
||||
def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
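The removed `parse_freq` shows the normalization rule that callers now reach through `Freq.parse`. As a standalone reminder of that rule, here is a sketch built from the removed body; whether `Freq.parse` in `.time` is identical is an assumption, since that module is not part of this hunk.

```python
import re
from typing import Tuple

_FREQ_ALIASES = {
    "month": "month",
    "mon": "month",
    "week": "week",
    "w": "week",
    "day": "day",
    "d": "day",
    "minute": "minute",
    "min": "minute",
}


def parse_freq(freq: str) -> Tuple[int, str]:
    """Split a raw freq such as "2mon" into (count, normalized unit)."""
    match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq.lower())
    if match_obj is None:
        raise ValueError(f"freq format is not supported: {freq!r}")
    # A missing count means one unit, e.g. "day" is (1, "day").
    count = int(match_obj.group(1)) if match_obj.group(1) else 1
    return count, _FREQ_ALIASES[match_obj.group(2)]
```

The normalized unit is what `resam_calendar` compares against constants such as `Freq.NORM_FREQ_MINUTE` after this change.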
@@ -75,46 +30,14 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
np.ndarray
|
||||
The calendar with frequency freq_sam
|
||||
"""
|
||||
raw_count, freq_raw = parse_freq(freq_raw)
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
raw_count, freq_raw = Freq.parse(freq_raw)
|
||||
sam_count, freq_sam = Freq.parse(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
|
||||
# if freq_sam is xminute, divide each trading day into several bars evenly
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_sam_minute(x, sam_minutes):
|
||||
"""
|
||||
Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time
|
||||
- open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
|
||||
- mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
|
||||
- mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
|
||||
- close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
|
||||
"""
|
||||
day_time = pd.Timestamp(x.date())
|
||||
shift = C.min_data_shift
|
||||
|
||||
open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
|
||||
mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
|
||||
mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
|
||||
close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)
|
||||
|
||||
if open_time <= x <= mid_close_time:
|
||||
minute_index = (x - open_time).seconds // 60
|
||||
elif mid_open_time <= x <= close_time:
|
||||
minute_index = (x - mid_open_time).seconds // 60 + 120
|
||||
else:
|
||||
raise ValueError("datetime of calendar is out of range")
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
return open_time + minute_index * pd.Timedelta(minutes=1)
|
||||
elif 120 <= minute_index < 240:
|
||||
return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
|
||||
else:
|
||||
raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")
|
||||
|
||||
if freq_raw != "minute":
|
||||
if freq_sam == Freq.NORM_FREQ_MINUTE:
|
||||
if freq_raw != Freq.NORM_FREQ_MINUTE:
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
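The minute sampling that this hunk moves out of `resam_calendar` maps each timestamp to the start of its bar, counting the morning session as minutes 0 to 119 and the afternoon session as 120 to 239. A stdlib sketch of the removed logic follows; `sample_minute` is a hypothetical name, and the `shift` parameter stands in for `C.min_data_shift`.

```python
from datetime import datetime, timedelta


def sample_minute(x: datetime, sam_minutes: int, shift: int = 0) -> datetime:
    """Floor a trading-minute timestamp to the start of its sam_minutes bar."""
    day_time = datetime(x.year, x.month, x.day)
    offset = timedelta(minutes=shift)
    open_time = day_time + timedelta(hours=9, minutes=30) - offset
    mid_close_time = day_time + timedelta(hours=11, minutes=29) - offset
    mid_open_time = day_time + timedelta(hours=13) - offset
    close_time = day_time + timedelta(hours=14, minutes=59) - offset

    if open_time <= x <= mid_close_time:
        minute_index = (x - open_time).seconds // 60
    elif mid_open_time <= x <= close_time:
        # The afternoon session continues the count after 120 morning minutes.
        minute_index = (x - mid_open_time).seconds // 60 + 120
    else:
        raise ValueError("datetime of calendar is out of range")

    # Floor to a multiple of the bar length.
    minute_index = minute_index // sam_minutes * sam_minutes
    if minute_index < 120:
        return open_time + timedelta(minutes=minute_index)
    return mid_open_time + timedelta(minutes=minute_index - 120)
```

Indexing both sessions on one 240-minute axis is what keeps bars from straddling the lunch break when `sam_minutes` divides 120.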
@@ -125,15 +48,15 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np
|
||||
# else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == "day":
|
||||
if freq_sam == Freq.NORM_FREQ_DAY:
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam == "week":
|
||||
elif freq_sam == Freq.NORM_FREQ_WEEK:
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam == "month":
|
||||
elif freq_sam == Freq.NORM_FREQ_MONTH:
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_count]
|
||||
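The week and month branches pick the first trading day of each period by looking for a drop in the day-of-week (or day-of-month) value between consecutive days. A pure-Python sketch of the weekly case, with `first_days_of_week` as a hypothetical name and a loop replacing `np.ediff1d(..., to_begin=-1) < 0`:

```python
from datetime import date


def first_days_of_week(calendar_day):
    """Keep each day whose weekday is lower than the previous day's weekday."""
    firsts = []
    prev = None
    for day in calendar_day:
        # The first entry is always kept, matching `to_begin=-1`.
        if prev is None or day.weekday() - prev.weekday() < 0:
            firsts.append(day)
        prev = day
    return firsts


days = [
    date(2020, 1, 2),   # Thursday
    date(2020, 1, 3),   # Friday
    date(2020, 1, 6),   # Monday
    date(2020, 1, 7),   # Tuesday
    date(2020, 1, 13),  # Monday
]
week_starts = first_days_of_week(days)
```

Slicing the result with `[::sam_count]`, as the diff does, then yields every `sam_count`-th week.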
@@ -175,7 +98,7 @@ def get_resam_calendar(
|
||||
|
||||
"""
|
||||
|
||||
_, norm_freq = parse_freq(freq)
|
||||
_, norm_freq = Freq.parse(freq)
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
@@ -184,7 +107,7 @@ def get_resam_calendar(
|
||||
freq, freq_sam = freq, None
|
||||
except (ValueError, KeyError):
|
||||
freq_sam = freq
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
|
||||
try:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
|
||||
@@ -195,7 +118,7 @@ def get_resam_calendar(
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "1min"
|
||||
elif norm_freq == "minute":
|
||||
elif norm_freq == Freq.NORM_FREQ_MINUTE:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future
|
||||
)
|
||||
@@ -205,6 +128,36 @@ def get_resam_calendar(
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=None, freq="day", disk_cache=1):
|
||||
"""get the feature with higher or equal frequency than `freq`.
|
||||
Returns
|
||||
-------
|
||||
pd.DataFrame
|
||||
the feature with higher or equal frequency
|
||||
"""
|
||||
|
||||
from ..data.data import D
|
||||
|
||||
try:
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache)
|
||||
_freq = freq
|
||||
except (ValueError, KeyError):
|
||||
_, norm_freq = Freq.parse(freq)
|
||||
if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]:
|
||||
try:
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache)
|
||||
_freq = "day"
|
||||
except (ValueError, KeyError):
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
|
||||
_freq = "1min"
|
||||
elif norm_freq == Freq.NORM_FREQ_MINUTE:
|
||||
_result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache)
|
||||
_freq = "1min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _result, _freq
|
||||
|
||||
|
||||
def resam_ts_data(
|
||||
ts_feature: Union[pd.DataFrame, pd.Series],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
@@ -273,14 +226,14 @@ def resam_ts_data(
|
||||
end sampling time, by default None
|
||||
method : Union[str, Callable], optional
|
||||
sample method, apply method function to each stock series data, by default "last"
|
||||
- If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
|
||||
- If `feature` has MultiIndex[instrument, datetime], method must be a member of pandas.groupby when it's type is str.or callable function.
|
||||
- If type(method) is str or callable function, it should be an attribute of SeriesGroupBy or DataFrameGroupBy, and groupby.method is applied to the sliced time-series data
|
||||
- If method is None, do nothing for the sliced time-series data.
|
||||
method_kwargs : dict, optional
|
||||
arguments of method, by default {}
|
||||
|
||||
Returns
|
||||
-------
|
||||
The Resampled DataFrame/Series/Value
|
||||
The resampled DataFrame/Series/value, return None when the resampled data is empty.
|
||||
"""
|
||||
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
@@ -293,7 +246,7 @@ def resam_ts_data(
|
||||
if datetime_level:
|
||||
feature = feature.loc[selector_datetime]
|
||||
else:
|
||||
feature = feature.loc[(slice(None), selector_datetime)]
|
||||
feature = feature.loc(axis=0)[(slice(None), selector_datetime)]
|
||||
|
||||
if feature.empty:
|
||||
return None
|
||||
|
||||
160
qlib/utils/time.py
Normal file
@@ -0,0 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""
Time related utils are compiled in this script
"""
import bisect
from datetime import datetime, time
from typing import List, Tuple
import re
from numpy import append
import pandas as pd
from qlib.config import C
import functools


@functools.lru_cache(maxsize=240)
def get_min_cal(shift: int = 0) -> List[time]:
    """
    get the minute-level calendar within a trading day

    Parameters
    ----------
    shift : int
        the shift direction would be like pandas shift.
        series.shift(1) will replace the value at `i`-th with the one at `i-1`-th

    Returns
    -------
    List[time]:

    """
    cal = []
    for ts in list(pd.date_range("9:30", "11:29", freq="1min") - pd.Timedelta(minutes=shift)) + list(
        pd.date_range("13:00", "14:59", freq="1min") - pd.Timedelta(minutes=shift)
    ):
        cal.append(ts.time())
    return cal


class Freq:
    NORM_FREQ_MONTH = "month"
    NORM_FREQ_WEEK = "week"
    NORM_FREQ_DAY = "day"
    NORM_FREQ_MINUTE = "minute"
    SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE]

    MIN_CAL = get_min_cal()

    def __init__(self, freq: str) -> None:
        self.count, self.base = self.parse(freq)

    @staticmethod
    def parse(freq: str) -> Tuple[int, str]:
        """
        Parse freq into a unified format

        Parameters
        ----------
        freq : str
            Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'

        Returns
        -------
        freq: Tuple[int, str]
            Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
                Example:

                .. code-block::

                    print(Freq.parse("day"))
                    (1, "day")
                    print(Freq.parse("2mon"))
                    (2, "month")
                    print(Freq.parse("10w"))
                    (10, "week")

        """
        freq = freq.lower()
        match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
        if match_obj is None:
            raise ValueError(
                "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
            )
        _count = int(match_obj.group(1)) if match_obj.group(1) else 1
        _freq = match_obj.group(2)
        _freq_format_dict = {
            "month": Freq.NORM_FREQ_MONTH,
            "mon": Freq.NORM_FREQ_MONTH,
            "week": Freq.NORM_FREQ_WEEK,
            "w": Freq.NORM_FREQ_WEEK,
            "day": Freq.NORM_FREQ_DAY,
            "d": Freq.NORM_FREQ_DAY,
            "minute": Freq.NORM_FREQ_MINUTE,
            "min": Freq.NORM_FREQ_MINUTE,
        }
        return _count, _freq_format_dict[_freq]


def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]:
    """
    get the min-bar index in a day for a time range (both left and right are closed) given a fixed frequency
    Parameters
    ----------
    start : str
        e.g. "9:30"
    end : str
        e.g. "14:30"
    freq : str
        "1min"

    Returns
    -------
    Tuple[int, int]:
        The index of start and end in the calendar. Both left and right are **closed**
    """
    start = pd.Timestamp(start).time()
    end = pd.Timestamp(end).time()
    freq = Freq(freq)
    in_day_cal = Freq.MIN_CAL[:: freq.count]
    left_idx = bisect.bisect_left(in_day_cal, start)
    right_idx = bisect.bisect_right(in_day_cal, end) - 1
    return left_idx, right_idx


def cal_sam_minute(x: pd.Timestamp, sam_minutes: int) -> pd.Timestamp:
    """
    align the minute-level data to a down sampled calendar

    e.g. align 10:38 to 10:35 in 5 minute-level(10:30 in 10 minute-level)

    Parameters
    ----------
    x : pd.Timestamp
        datetime to be aligned
    sam_minutes : int
        align to `sam_minutes` minute-level calendar

    Returns
    -------
    pd.Timestamp:
        the aligned datetime
    """
    cal = get_min_cal(C.min_data_shift)[::sam_minutes]
    idx = bisect.bisect_right(cal, x.time()) - 1
    date, new_time = x.date(), cal[idx]
    return pd.Timestamp(
        datetime(
            date.year,
            month=date.month,
            day=date.day,
            hour=new_time.hour,
            minute=new_time.minute,
            second=new_time.second,
            microsecond=new_time.microsecond,
        )
    )


if __name__ == "__main__":
    print(get_day_min_idx_range("8:30", "14:59", "10min"))
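The frequency parsing and in-day index lookup added in `qlib/utils/time.py` can be illustrated without pandas or qlib. The following is a minimal sketch using only the standard library; the names `parse_freq`, `min_cal`, and `day_min_idx_range` are hypothetical stand-ins for `Freq.parse`, `get_min_cal`, and `get_day_min_idx_range`, and the calendar is assumed to be unshifted (`shift=0`).

```python
import bisect
import re
from datetime import datetime, time, timedelta
from typing import List, Tuple

# Hypothetical stand-in for the unit table used by Freq.parse.
_UNITS = {
    "month": "month", "mon": "month",
    "week": "week", "w": "week",
    "day": "day", "d": "day",
    "minute": "minute", "min": "minute",
}


def parse_freq(freq: str) -> Tuple[int, str]:
    """Split a raw freq such as '10w' into (count, normalized unit)."""
    m = re.match(r"^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq.lower())
    if m is None:
        raise ValueError(f"unsupported freq: {freq!r}")
    count = int(m.group(1)) if m.group(1) else 1
    return count, _UNITS[m.group(2)]


def min_cal() -> List[time]:
    """240 one-minute bars: 9:30-11:29 and 13:00-14:59 (no shift assumed)."""
    cal = []
    for hour, minute in [(9, 30), (13, 0)]:
        session_open = datetime(2000, 1, 1, hour, minute)  # the date is arbitrary
        for i in range(120):
            cal.append((session_open + timedelta(minutes=i)).time())
    return cal


def day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]:
    """Closed [left, right] bar indices covering start..end at the given freq."""
    count, _ = parse_freq(freq)
    cal = min_cal()[::count]
    start_t = datetime.strptime(start, "%H:%M").time()
    end_t = datetime.strptime(end, "%H:%M").time()
    return bisect.bisect_left(cal, start_t), bisect.bisect_right(cal, end_t) - 1


print(parse_freq("2mon"))                            # (2, 'month')
print(day_min_idx_range("8:30", "14:59", "10min"))   # (0, 23)
```

With a 10-minute step the calendar has 24 bars (12 per session), so a range that starts before the open and ends at the last minute maps to the closed index pair `(0, 23)`.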
@@ -18,10 +18,12 @@ There are 4 total situations for using different trainers in different situation
|
||||
========================= ===================================================================================
|
||||
Situations Description
|
||||
========================= ===================================================================================
|
||||
Online + Trainer When you REAL want to do a routine, the Trainer will help you train the models.
|
||||
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
|
||||
will train models task by task and strategy by strategy.
|
||||
|
||||
Online + DelayTrainer In normal online routine, whether Trainer or DelayTrainer will REAL train models
|
||||
in this routine. So it is not necessary to use DelayTrainer when do a REAL routine.
|
||||
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
|
||||
nothing until all tasks have been prepared. It lets users train all tasks at
|
||||
the end of `routine` or `first_train`.
|
||||
|
||||
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
|
||||
need to consider using Trainer. This means it will REAL train your models in
|
||||
@@ -103,17 +105,21 @@ class OnlineManager(Serializable):
|
||||
"""
|
||||
if strategies is None:
|
||||
strategies = self.strategies
|
||||
for strategy in strategies:
|
||||
|
||||
models_list = []
|
||||
for strategy in strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins first training...")
|
||||
tasks = strategy.first_tasks()
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
|
||||
def routine(
|
||||
self,
|
||||
cur_time: Union[str, pd.Timestamp] = None,
|
||||
@@ -139,33 +145,38 @@ class OnlineManager(Serializable):
|
||||
cur_time = D.calendar(freq=self.freq).max()
|
||||
self.cur_time = pd.Timestamp(cur_time) # None for latest date
|
||||
|
||||
models_list = []
|
||||
for strategy in self.strategies:
|
||||
self.logger.info(f"Strategy `{strategy.name_id}` begins routine...")
|
||||
if self.status == self.STATUS_NORMAL:
|
||||
strategy.tool.update_online_pred()
|
||||
|
||||
tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs)
|
||||
models = self.trainer.train(tasks)
|
||||
if self.status == self.STATUS_NORMAL or not self.trainer.is_delay():
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
|
||||
models_list.append(models)
|
||||
self.logger.info(f"Finished training {len(models)} models.")
|
||||
online_models = strategy.prepare_online_models(models, **model_kwargs)
|
||||
self.history.setdefault(self.cur_time, {})[strategy] = online_models
|
||||
|
||||
if not self.trainer.is_delay():
|
||||
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
|
||||
for strategy, models in zip(self.strategies, models_list):
|
||||
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
|
||||
def get_collector(self) -> MergeCollector:
|
||||
def get_collector(self, **kwargs) -> MergeCollector:
|
||||
"""
|
||||
Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy.
|
||||
This collector can serve as the basis for signal preparation.
|
||||
|
||||
Args:
|
||||
**kwargs: the params for get_collector.
|
||||
|
||||
Returns:
|
||||
MergeCollector: the collector to merge other collectors.
|
||||
"""
|
||||
collector_dict = {}
|
||||
for strategy in self.strategies:
|
||||
collector_dict[strategy.name_id] = strategy.get_collector()
|
||||
collector_dict[strategy.name_id] = strategy.get_collector(**kwargs)
|
||||
return MergeCollector(collector_dict, process_list=[])
|
||||
|
||||
def add_strategy(self, strategies: Union[OnlineStrategy, List[OnlineStrategy]]):
|
||||
@@ -297,6 +308,7 @@ class OnlineManager(Serializable):
|
||||
# NOTE: Assumption: the predictions of the online models must end before the next cur_time; otherwise this method will not work correctly.
|
||||
self.prepare_signals(**signal_kwargs)
|
||||
if signals_time > cur_time:
|
||||
# FIXME: if DelayTrainer and a worker are used (and the worker is faster than the main process), this warning may be shown.
|
||||
self.logger.warn(
|
||||
f"The signals have already been prepared to {signals_time} by the last preparation, but the current time is only {cur_time}. This may be because the online models predict more than they should, which can cause signals to be contaminated by the offline models."
|
||||
)
|
||||
|
||||
@@ -7,7 +7,8 @@ import warnings
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
from ..contrib.evaluate import risk_analysis
|
||||
from typing import Union, List
|
||||
from ..contrib.evaluate import indicator_analysis, risk_analysis
|
||||
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
@@ -15,9 +16,9 @@ from ..backtest import backtest as normal_backtest
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..utils.resam import parse_freq
|
||||
from ..utils.time import Freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
|
||||
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
@@ -294,7 +295,15 @@ class PortAnaRecord(RecordTemp):
|
||||
|
||||
artifact_path = "portfolio_analysis"
|
||||
|
||||
def __init__(self, recorder, config, risk_analysis_freq, **kwargs):
|
||||
def __init__(
|
||||
self,
|
||||
recorder,
|
||||
config,
|
||||
risk_analysis_freq: Union[List, str] = None,
|
||||
indicator_analysis_freq: Union[List, str] = None,
|
||||
indicator_analysis_method=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
config["strategy"] : dict
|
||||
define the strategy class as well as the kwargs.
|
||||
@@ -302,22 +311,50 @@ class PortAnaRecord(RecordTemp):
|
||||
define the executor class as well as the kwargs.
|
||||
config["backtest"] : dict
|
||||
define the backtest kwargs.
|
||||
risk_analysis_freq : int
|
||||
risk_analysis_freq : str|List[str]
|
||||
risk analysis freq of report
|
||||
indicator_analysis_freq : str|List[str]
|
||||
indicator analysis freq of report
|
||||
indicator_analysis_method : str, optional, by default None
|
||||
the candidate values are 'mean', 'amount_weighted', 'value_weighted'
|
||||
"""
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
|
||||
self.strategy_config = config["strategy"]
|
||||
self.executor_config = config["executor"]
|
||||
_default_executor_config = {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"generate_report": True,
|
||||
},
|
||||
}
|
||||
self.executor_config = config.get("executor", _default_executor_config)
|
||||
self.backtest_config = config["backtest"]
|
||||
_count, _freq = parse_freq(risk_analysis_freq)
|
||||
self.risk_analysis_freq = f"{_count}{_freq}"
|
||||
self.report_freq = self._get_report_freq(self.executor_config)
|
||||
|
||||
self.all_freq = self._get_report_freq(self.executor_config)
|
||||
if risk_analysis_freq is None:
|
||||
risk_analysis_freq = [self.all_freq[0]]
|
||||
if indicator_analysis_freq is None:
|
||||
indicator_analysis_freq = [self.all_freq[0]]
|
||||
|
||||
if isinstance(risk_analysis_freq, str):
|
||||
risk_analysis_freq = [risk_analysis_freq]
|
||||
if isinstance(indicator_analysis_freq, str):
|
||||
indicator_analysis_freq = [indicator_analysis_freq]
|
||||
|
||||
self.risk_analysis_freq = [
|
||||
"{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq
|
||||
]
|
||||
self.indicator_analysis_freq = [
|
||||
"{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq
|
||||
]
|
||||
self.indicator_analysis_method = indicator_analysis_method
|
||||
|
||||
def _get_report_freq(self, executor_config):
|
||||
ret_freq = []
|
||||
if executor_config["kwargs"].get("generate_report", False):
|
||||
_count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"])
|
||||
_count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"])
|
||||
ret_freq.append(f"{_count}{_freq}")
|
||||
if "sub_env" in executor_config["kwargs"]:
|
||||
ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))
|
||||
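The recursive collection of report frequencies in `_get_report_freq` can be sketched in isolation. This is a simplified illustration, not the class method itself: `normalize_freq` and `collect_report_freqs` are hypothetical names, the nested config below is made up for the example, and the function is assumed to return the accumulated list (the return statement falls outside the hunk shown above).

```python
import re
from typing import List

# Hypothetical stand-in for the unit normalization done by Freq.parse.
_UNITS = {
    "month": "month", "mon": "month",
    "week": "week", "w": "week",
    "day": "day", "d": "day",
    "minute": "minute", "min": "minute",
}


def normalize_freq(freq: str) -> str:
    """Turn a raw freq such as '30min' into the '<count><unit>' form, e.g. '30minute'."""
    m = re.match(r"^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq.lower())
    if m is None:
        raise ValueError(f"unsupported freq: {freq!r}")
    count = int(m.group(1)) if m.group(1) else 1
    return f"{count}{_UNITS[m.group(2)]}"


def collect_report_freqs(executor_config: dict) -> List[str]:
    """Walk the executor config and its nested `sub_env`, outermost level first."""
    kwargs = executor_config["kwargs"]
    freqs = []
    if kwargs.get("generate_report", False):
        freqs.append(normalize_freq(kwargs["time_per_step"]))
    if "sub_env" in kwargs:
        freqs.extend(collect_report_freqs(kwargs["sub_env"]))
    return freqs


# Made-up two-level config: a daily executor wrapping a 30-minute one.
nested = {
    "kwargs": {
        "time_per_step": "day",
        "generate_report": True,
        "sub_env": {"kwargs": {"time_per_step": "30min", "generate_report": True}},
    }
}
all_freq = collect_report_freqs(nested)
print(all_freq)  # ['1day', '30minute']

# When no analysis freq is given, the diff falls back to the first collected freq.
risk_analysis_freq = [all_freq[0]]
print(risk_analysis_freq)  # ['1day']
```

Because the outermost executor is visited first, the default analysis frequency is the coarsest level that generates a report.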
@@ -325,55 +362,97 @@ class PortAnaRecord(RecordTemp):
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# custom strategy and get backtest
|
||||
report_dict = normal_backtest(
|
||||
report_dict, indicator_dict = normal_backtest(
|
||||
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
|
||||
)
|
||||
for report_freq, (report_normal, positions_normal) in report_dict.items():
|
||||
for _freq, (report_normal, positions_normal) in report_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"report_normal_{report_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
self.recorder.save_objects(
|
||||
**{f"positions_normal_{report_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
|
||||
if self.risk_analysis_freq not in report_dict:
|
||||
warnings.warn(
|
||||
f"the freq {self.risk_analysis_freq} report is not found, please set the corresponding env with `generate_report==True`"
|
||||
)
|
||||
else:
|
||||
report_normal, _ = report_dict.get(self.risk_analysis_freq)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=self.risk_analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=self.risk_analysis_freq
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
self.recorder.log_metrics(**flatten_dict(analysis_df["risk"].unstack().T.to_dict()))
|
||||
# save results
|
||||
for _freq, indicators_normal in indicator_dict.items():
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis_{report_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{report_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
# print out results
|
||||
pprint("The following are analysis results of the excess return without cost.")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint("The following are analysis results of the excess return with cost.")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq not in report_dict:
|
||||
warnings.warn(
|
||||
f"the freq {_analysis_freq} report is not found, please set the corresponding env with `generate_report=True`"
|
||||
)
|
||||
else:
|
||||
report_normal, _ = report_dict.get(_analysis_freq)
|
||||
analysis = dict()
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], freq=_analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], freq=_analysis_freq
|
||||
)
|
||||
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
logger.info(
|
||||
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
# print out results
|
||||
pprint(f"The following are analysis results of benchmark return({_analysis_freq}).")
|
||||
pprint(risk_analysis(report_normal["bench"], freq=_analysis_freq))
|
||||
pprint(f"The following are analysis results of the excess return without cost({_analysis_freq}).")
|
||||
pprint(analysis["excess_return_without_cost"])
|
||||
pprint(f"The following are analysis results of the excess return with cost({_analysis_freq}).")
|
||||
pprint(analysis["excess_return_with_cost"])
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq not in indicator_dict:
|
||||
warnings.warn(f"the freq {_analysis_freq} indicator is not found")
|
||||
else:
|
||||
indicators_normal = indicator_dict.get(_analysis_freq)
|
||||
if self.indicator_analysis_method is None:
|
||||
analysis_df = indicator_analysis(indicators_normal)
|
||||
else:
|
||||
analysis_df = indicator_analysis(indicators_normal, method=self.indicator_analysis_method)
|
||||
# log metrics
|
||||
analysis_dict = analysis_df["value"].to_dict()
|
||||
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
|
||||
# save results
|
||||
self.recorder.save_objects(
|
||||
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
|
||||
)
|
||||
logger.info(
|
||||
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
|
||||
)
|
||||
pprint(f"The following are analysis results of indicators({_analysis_freq}).")
|
||||
pprint(analysis_df)
|
||||
|
||||
def list(self):
|
||||
list_path = []
|
||||
for _freq in self.report_freq:
|
||||
for _freq in self.all_freq:
|
||||
list_path.extend(
|
||||
[
|
||||
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
|
||||
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
|
||||
]
|
||||
)
|
||||
if _freq == self.risk_analysis_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_freq}.pkl"))
|
||||
for _analysis_freq in self.risk_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
|
||||
else:
|
||||
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
for _analysis_freq in self.indicator_analysis_freq:
|
||||
if _analysis_freq in self.all_freq:
|
||||
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
|
||||
else:
|
||||
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
|
||||
|
||||
return list_path
|
||||
|
||||
@@ -69,28 +69,29 @@ class TaskManager:
|
||||
|
||||
ENCODE_FIELDS_PREFIX = ["def", "res"]
|
||||
|
||||
def __init__(self, task_pool: str = None):
|
||||
def __init__(self, task_pool: str):
|
||||
"""
|
||||
Init Task Manager. Remember to declare the MongoDB URL and database name first.
|
||||
A TaskManager instance serves a specific task pool.
|
||||
The static method of this module serves the whole MongoDB.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
task_pool: str
|
||||
the name of Collection in MongoDB
|
||||
"""
|
||||
self.mdb = get_mongodb()
|
||||
if task_pool is not None:
|
||||
self.task_pool = getattr(self.mdb, task_pool)
|
||||
self.task_pool = getattr(get_mongodb(), task_pool)
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
|
||||
def list(self) -> list:
|
||||
@staticmethod
|
||||
def list() -> list:
|
||||
"""
|
||||
List the all collection(task_pool) of the db
|
||||
List all collections (task pools) of the db.
|
||||
|
||||
Returns:
|
||||
list
|
||||
"""
|
||||
return self.mdb.list_collection_names()
|
||||
return get_mongodb().list_collection_names()
|
||||
|
||||
def _encode_task(self, task):
|
||||
for prefix in self.ENCODE_FIELDS_PREFIX:
|
||||
@@ -109,6 +110,25 @@ class TaskManager:
|
||||
def _dict_to_str(self, flt):
|
||||
return {k: str(v) for k, v in flt.items()}
|
||||
|
||||
def _decode_query(self, query):
|
||||
"""
|
||||
If the query includes any `_id`, then it needs `ObjectId` to decode.
|
||||
For example, TrainerRM issues the query `{"_id": {"$in": _id_list}}`, so every `_id` in `_id_list` must be converted to `ObjectId`.
|
||||
|
||||
Args:
|
||||
query (dict): query dict. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
dict: the query after decoding.
|
||||
"""
|
||||
if "_id" in query:
|
||||
if isinstance(query["_id"], dict):
|
||||
for key in query["_id"]:
|
||||
query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
|
||||
else:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
return query
|
||||
|
||||
def replace_task(self, task, new_task):
|
||||
"""
|
||||
Use a new task to replace an old one
|
||||
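The `_id` handling in `_decode_query` can be tried out without a MongoDB installation. The sketch below is a minimal illustration under stated assumptions: the `ObjectId` class here is a hypothetical stand-in for `bson.ObjectId`, and `decode_query` is a free function rather than the method from the diff. It also deep-copies the query, which is a choice made for this sketch; the call sites in the diff use `query.copy()`, a shallow copy, before decoding.

```python
import copy


class ObjectId:
    """Hypothetical stand-in for bson.ObjectId, only wrapping the id string."""

    def __init__(self, oid):
        self.oid = str(oid)

    def __eq__(self, other):
        return isinstance(other, ObjectId) and self.oid == other.oid

    def __hash__(self):
        return hash(self.oid)

    def __repr__(self):
        return f"ObjectId({self.oid!r})"


def decode_query(query: dict) -> dict:
    """Wrap every `_id` value in ObjectId, for both plain and operator queries."""
    query = copy.deepcopy(query)  # sketch-only choice: leave the caller's dict untouched
    if "_id" in query:
        if isinstance(query["_id"], dict):
            # Operator form, e.g. {"_id": {"$in": [id1, id2]}}
            for key in query["_id"]:
                query["_id"][key] = [ObjectId(i) for i in query["_id"][key]]
        else:
            # Plain form, e.g. {"_id": id1}
            query["_id"] = ObjectId(query["_id"])
    return query


raw = {"_id": {"$in": ["a1", "b2"]}, "status": "waiting"}
print(decode_query(raw))
print(decode_query({"_id": "a1"}))
```

With a shallow copy, the inner `{"$in": [...]}` dict is shared with the caller and would be rewritten in place; the deep copy in this sketch avoids that side effect at the cost of copying the whole query.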
@@ -224,8 +244,7 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
query.update({"status": status})
|
||||
task = self.task_pool.find_one_and_update(
|
||||
query, {"$set": {"status": self.STATUS_RUNNING}}, sort=[("priority", pymongo.DESCENDING)]
|
||||
@@ -283,12 +302,11 @@ class TaskManager:
|
||||
dict: a task(document in collection) after decoding
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
for t in self.task_pool.find(query):
|
||||
yield self._decode_task(t)
|
||||
|
||||
def re_query(self, _id):
|
||||
def re_query(self, _id) -> dict:
|
||||
"""
|
||||
Use _id to query task.
|
||||
|
||||
@@ -339,8 +357,7 @@ class TaskManager:
|
||||
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
self.task_pool.delete_many(query)
|
||||
|
||||
def task_stat(self, query={}) -> dict:
|
||||
@@ -354,8 +371,7 @@ class TaskManager:
|
||||
dict
|
||||
"""
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
tasks = self.query(query=query, decode=False)
|
||||
status_stat = {}
|
||||
for t in tasks:
|
||||
@@ -377,8 +393,7 @@ class TaskManager:
|
||||
|
||||
def reset_status(self, query, status):
|
||||
query = query.copy()
|
||||
if "_id" in query:
|
||||
query["_id"] = ObjectId(query["_id"])
|
||||
query = self._decode_query(query)
|
||||
print(self.task_pool.update_many(query, {"$set": {"status": status}}))
|
||||
|
||||
def prioritize(self, task, priority: int):
|
||||
@@ -402,9 +417,19 @@ class TaskManager:
|
||||
return sum(task_stat.values())
|
||||
|
||||
def wait(self, query={}):
|
||||
"""
|
||||
When multiprocessing, the main process may fetch nothing from TaskManager because there are still some running tasks.
|
||||
So the main process should wait until all tasks have been trained by other processes or machines.
|
||||
|
||||
Args:
|
||||
query (dict, optional): the query dict. Defaults to {}.
|
||||
"""
|
||||
task_stat = self.task_stat(query)
|
||||
total = self._get_total(task_stat)
|
||||
last_undone_n = self._get_undone_n(task_stat)
|
||||
if last_undone_n == 0:
|
||||
return
|
||||
self.logger.warn(f"Waiting for {last_undone_n} undone tasks. Please make sure they are running.")
|
||||
with tqdm(total=total, initial=total - last_undone_n) as pbar:
|
||||
while True:
|
||||
time.sleep(10)
|
||||
|
||||
@@ -17,7 +17,6 @@ def experiment_exit_handler():
|
||||
Thus, if any exception or user interruption occurs beforehand, we should handle them first. Once `R` is
|
||||
ended, another call of `R.end_exp` will not take effect.
|
||||
"""
|
||||
signal.signal(signal.SIGINT, experiment_kill_signal_handler) # handle user keyboard interupt
|
||||
sys.excepthook = experiment_exception_hook # handle uncaught exception
|
||||
atexit.register(R.end_exp, recorder_status=Recorder.STATUS_FI) # will not take effect if experiment ends
|
||||
|
||||
@@ -39,11 +38,3 @@ def experiment_exception_hook(type, value, tb):
|
||||
print(f"{type.__name__}: {value}")
|
||||
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
|
||||
|
||||
def experiment_kill_signal_handler(signum, frame):
|
||||
"""
|
||||
End an experiment when user kill the program through keyboard (CTRL+C, etc.).
|
||||
"""
|
||||
R.end_exp(recorder_status=Recorder.STATUS_FA)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
2
setup.py
@@ -45,7 +45,7 @@ REQUIRED = [
|
||||
"statsmodels",
|
||||
"xlrd>=1.0.0",
|
||||
"plotly==4.12.0",
|
||||
"matplotlib==3.1.3",
|
||||
"matplotlib==3.3",
|
||||
"tables>=3.6.1",
|
||||
"pyyaml>=5.3.1",
|
||||
"mlflow>=1.12.1",
|
||||
|
||||
89
tests/misc/test_utils.py
Normal file
@@ -0,0 +1,89 @@
from unittest.case import TestCase
import unittest
import pandas as pd
import numpy as np
from datetime import datetime
from qlib import init
from qlib.config import C
from qlib.log import TimeInspector
from qlib.utils.time import cal_sam_minute as cal_sam_minute_new, get_min_cal


def cal_sam_minute(x, sam_minutes):
    """
    Sample raw calendar into calendar with sam_minutes freq; `shift` is the number of minutes by which the market time is shifted
        - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)]
        - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)]
        - mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)]
        - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)]
    """
    # TODO: actually, this version is much faster when no cache or optimization
    day_time = pd.Timestamp(x.date())
    shift = C.min_data_shift

    open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1)
    mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1)
    mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1)
    close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1)

    if open_time <= x <= mid_close_time:
        minute_index = (x - open_time).seconds // 60
    elif mid_open_time <= x <= close_time:
        minute_index = (x - mid_open_time).seconds // 60 + 120
    else:
        raise ValueError("datetime of calendar is out of range")
    minute_index = minute_index // sam_minutes * sam_minutes

    if 0 <= minute_index < 120:
        return open_time + minute_index * pd.Timedelta(minutes=1)
    elif 120 <= minute_index < 240:
        return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1)
    else:
        raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C")


class TimeUtils(TestCase):
    @classmethod
    def setUpClass(cls):
        init()

    def test_cal_sam_minute(self):
        # test the correctness of the code
        random_n = 1000
        cal = get_min_cal()

        def gen_args():
            for time in np.random.choice(cal, size=random_n, replace=True):
                sam_minutes = np.random.choice([1, 2, 3, 4, 5, 6])
                dt = pd.Timestamp(
                    datetime(
                        2021,
                        month=3,
                        day=3,
                        hour=time.hour,
                        minute=time.minute,
                        second=time.second,
                        microsecond=time.microsecond,
                    )
                )
                args = dt, sam_minutes
                yield args

        for args in gen_args():
            assert cal_sam_minute(*args) == cal_sam_minute_new(*args)

        # test the performance of the code

        args_l = list(gen_args())

        with TimeInspector.logt():
            for args in args_l:
                cal_sam_minute(*args)

        with TimeInspector.logt():
            for args in args_l:
                cal_sam_minute_new(*args)


if __name__ == "__main__":
    unittest.main()
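The test above checks that the bisect-based `cal_sam_minute` agrees with the older arithmetic version. The same equivalence can be illustrated with the standard library alone. This is a minimal sketch, not the qlib code: `min_cal`, `align_bisect`, and `align_arith` are hypothetical names, plain `datetime` replaces `pd.Timestamp`, and `shift` is passed explicitly instead of being read from `C.min_data_shift`.

```python
import bisect
from datetime import datetime, time, timedelta
from typing import List


def min_cal(shift: int = 0) -> List[time]:
    """One-minute bars for 9:30-11:29 and 13:00-14:59, moved earlier by `shift` minutes."""
    cal = []
    for hour, minute in [(9, 30), (13, 0)]:
        session_open = datetime(2000, 1, 1, hour, minute) - timedelta(minutes=shift)
        cal.extend((session_open + timedelta(minutes=i)).time() for i in range(120))
    return cal


def align_bisect(x: datetime, sam_minutes: int, shift: int = 0) -> datetime:
    """Snap x to the latest bar of the down-sampled calendar that is <= x."""
    cal = min_cal(shift)[::sam_minutes]
    idx = bisect.bisect_right(cal, x.time()) - 1
    return datetime.combine(x.date(), cal[idx])


def align_arith(x: datetime, sam_minutes: int, shift: int = 0) -> datetime:
    """Same result via index arithmetic over the two trading sessions."""
    day = datetime.combine(x.date(), time())
    open_time = day + timedelta(hours=9, minutes=30 - shift)
    mid_open_time = day + timedelta(hours=13, minutes=-shift)
    if open_time <= x < open_time + timedelta(minutes=120):
        idx = (x - open_time).seconds // 60
    elif mid_open_time <= x < mid_open_time + timedelta(minutes=120):
        idx = (x - mid_open_time).seconds // 60 + 120
    else:
        raise ValueError("datetime is outside the trading calendar")
    idx = idx // sam_minutes * sam_minutes  # floor to a multiple of sam_minutes
    if idx < 120:
        return open_time + timedelta(minutes=idx)
    return mid_open_time + timedelta(minutes=idx - 120)


x = datetime(2021, 3, 3, 10, 38)
print(align_bisect(x, 5))   # 2021-03-03 10:35:00
print(align_bisect(x, 10))  # 2021-03-03 10:30:00
```

Both versions floor the bar index to a multiple of `sam_minutes` over the concatenated 240-bar day, so a bar early in the afternoon session can align back to a morning bar when `sam_minutes` does not divide 120.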