mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-13 09:20:59 +08:00
Compare commits
14 Commits
mini_proje
...
v0.8.6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
36950b905d | ||
|
|
58540f76ee | ||
|
|
3e6e2865ce | ||
|
|
3fcbaa33fa | ||
|
|
50409ff17b | ||
|
|
afcea404a5 | ||
|
|
e24ef67663 | ||
|
|
2d5eecb9a2 | ||
|
|
89972f6c6f | ||
|
|
1ef8e61abd | ||
|
|
1a4114b683 | ||
|
|
e874ef2bc1 | ||
|
|
14b2b355a7 | ||
|
|
64fadff218 |
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -72,7 +72,7 @@ jobs:
|
||||
run: |
|
||||
pip install --upgrade pip
|
||||
pip install pylint
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0201,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
pylint --disable=C0104,C0114,C0115,C0116,C0301,C0302,C0411,C0413,C1802,R0401,R0801,R0902,R0903,R0911,R0912,R0913,R0914,R0915,R1720,W0105,W0123,W0201,W0511,W0613,W1113,W1514,E0401,E1121,C0103,C0209,R0402,R1705,R1710,R1725,R1735,W0102,W0212,W0221,W0223,W0231,W0237,W0612,W0621,W0622,W0703,W1309,E1102,E1136 --const-rgx='[a-z_][a-z0-9_]{2,30}$' qlib --init-hook "import astroid; astroid.context.InferenceContext.max_inferred = 500"
|
||||
|
||||
# The following flake8 error codes were ignored:
|
||||
# E501 line too long
|
||||
@@ -110,7 +110,7 @@ jobs:
|
||||
run: |
|
||||
python scripts/get_data.py qlib_data --name qlib_data_simple --target_dir ~/.qlib/qlib_data/cn_data_simple --interval 1d --region cn
|
||||
python -c "import os; userpath=os.path.expanduser('~'); os.rename(userpath + '/.qlib/qlib_data/cn_data_simple', userpath + '/.qlib/qlib_data/cn_data')"
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data /tmp/qlibpublic --recursive
|
||||
azcopy copy https://qlibpublic.blob.core.windows.net/data/rl /tmp/qlibpublic/data --recursive
|
||||
mv /tmp/qlibpublic/data tests/.data
|
||||
|
||||
- name: Test workflow by config (install from pip)
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -44,3 +44,4 @@ tags
|
||||
*.swp
|
||||
|
||||
./pretrain
|
||||
.idea/
|
||||
|
||||
@@ -66,7 +66,7 @@ TopkDropoutStrategy
|
||||
- Adopt the ``Topk-Drop`` algorithm to calculate the target amount of each stock
|
||||
|
||||
.. note::
|
||||
There are two parameters for the ``Topk-Drop`` algorithm:
|
||||
There are two parameters for the ``Topk-Drop`` algorithm:
|
||||
|
||||
- `Topk`: The number of stocks held
|
||||
- `Drop`: The number of stocks sold on each trading day
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
pandas==1.1.2
|
||||
numpy==1.21.0
|
||||
lightgbm==3.1.0
|
||||
lightgbm
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.2
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha158
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -0,0 +1,80 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data"
|
||||
region: cn
|
||||
market: &market csi500
|
||||
benchmark: &benchmark SH000905
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: 2008-01-01
|
||||
end_time: 2020-08-01
|
||||
fit_start_time: 2008-01-01
|
||||
fit_end_time: 2014-12-31
|
||||
instruments: *market
|
||||
infer_processors: []
|
||||
learn_processors:
|
||||
- class: DropnaLabel
|
||||
- class: CSRankNorm
|
||||
kwargs:
|
||||
fields_group: label
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
port_analysis_config: &port_analysis_config
|
||||
strategy:
|
||||
class: TopkDropoutStrategy
|
||||
module_path: qlib.contrib.strategy
|
||||
kwargs:
|
||||
signal:
|
||||
- <MODEL>
|
||||
- <DATASET>
|
||||
topk: 50
|
||||
n_drop: 5
|
||||
backtest:
|
||||
start_time: 2017-01-01
|
||||
end_time: 2020-08-01
|
||||
account: 100000000
|
||||
benchmark: *benchmark
|
||||
exchange_kwargs:
|
||||
limit_threshold: 0.095
|
||||
deal_price: close
|
||||
open_cost: 0.0005
|
||||
close_cost: 0.0015
|
||||
min_cost: 5
|
||||
task:
|
||||
model:
|
||||
class: LGBModel
|
||||
module_path: qlib.contrib.model.gbdt
|
||||
kwargs:
|
||||
loss: mse
|
||||
colsample_bytree: 0.8879
|
||||
learning_rate: 0.0421
|
||||
subsample: 0.8789
|
||||
lambda_l1: 205.6999
|
||||
lambda_l2: 580.9768
|
||||
max_depth: 8
|
||||
num_leaves: 210
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: DatasetH
|
||||
module_path: qlib.data.dataset
|
||||
kwargs:
|
||||
handler:
|
||||
class: Alpha360
|
||||
module_path: qlib.contrib.data.handler
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [2008-01-01, 2014-12-31]
|
||||
valid: [2015-01-01, 2016-12-31]
|
||||
test: [2017-01-01, 2020-08-01]
|
||||
record:
|
||||
- class: SignalRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
model: <MODEL>
|
||||
dataset: <DATASET>
|
||||
- class: SigAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
ana_long_short: False
|
||||
ann_scaler: 252
|
||||
- class: PortAnaRecord
|
||||
module_path: qlib.workflow.record_temp
|
||||
kwargs:
|
||||
config: *port_analysis_config
|
||||
@@ -20,7 +20,9 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
> NOTE:
|
||||
> We have very limited resources to implement and finetune the models. We tried our best effort to fairly compare these models. But some models may have greater potential than what it looks like in the table below. Your contribution is highly welcomed to explore their potential.
|
||||
|
||||
## Alpha158 dataset
|
||||
## Results on CSI300
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------------------------------------|-------------------------------------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
@@ -44,7 +46,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| DoubleEnsemble(Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4340±0.00 | 0.0523±0.00 | 0.4284±0.01 | 0.1168±0.01 | 1.3384±0.12 | -0.1036±0.01 |
|
||||
|
||||
|
||||
## Alpha360 dataset
|
||||
### Alpha360 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|-------------------------------------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
@@ -79,6 +81,38 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
- Signal-based evaluation: IC, ICIR, Rank IC, Rank ICIR
|
||||
- Portfolio-based metrics: Annualized Return, Information Ratio, Max Drawdown
|
||||
|
||||
## Results on CSI500
|
||||
The results on CSI500 is not complete. PR's for models on csi500 are welcome!
|
||||
|
||||
Transfer previous models in CSI300 to CSI500 is quite easy. You can try models with just a few commands below.
|
||||
```
|
||||
cd examples/benchmarks/LightGBM
|
||||
pip install -r requirements.txt
|
||||
|
||||
# create new config and set the benchmark to csi500
|
||||
cp workflow_config_lightgbm_Alpha158.yaml workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
sed -i "s/csi300/csi500/g" workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
sed -i "s/SH000300/SH000905/g" workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
# you can either run the model once
|
||||
qrun workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
# or run it for multiple times automatically and get the summarized results.
|
||||
cd ../../
|
||||
python run_all_model.py run 3 lightgbm Alpha158 csi500 # for models with randomness. please run it for 20 times.
|
||||
```
|
||||
|
||||
### Alpha158 dataset
|
||||
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha158 | 0.0377±0.00 | 0.3860±0.00 | 0.0448±0.00 | 0.4675±0.00 | 0.1151±0.00 | 1.3884±0.00 | -0.0898±0.00 |
|
||||
|
||||
### Alpha360 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
|------------|----------|-------------|-------------|-------------|-------------|-------------------|-------------------|--------------|
|
||||
| LightGBM | Alpha360 | 0.0400±0.00 | 0.3605±0.00 | 0.0536±0.00 | 0.5431±0.00 | 0.0505±0.00 | 0.7658±0.02 | -0.1880±0.00 |
|
||||
|
||||
|
||||
# Contributing
|
||||
|
||||
|
||||
@@ -28,6 +28,8 @@ The default forecasting models are `Linear`. Users can choose other forecasting
|
||||
The results of related methods in Qlib's public dataset can be found [here](../)
|
||||
|
||||
# Requirements
|
||||
Here is the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
|
||||
Here are the minimal hardware requirements to run the ``workflow.py`` of DDG-DA.
|
||||
* Memory: 45G
|
||||
* Disk: 4G
|
||||
|
||||
Pytorch with CPU & RAM will be enough for this example.
|
||||
|
||||
@@ -117,8 +117,10 @@ def get_all_folders(models, exclude) -> dict:
|
||||
|
||||
|
||||
# function to get all the files under the model folder
|
||||
def get_all_files(folder_path, dataset) -> (str, str):
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}*.yaml")
|
||||
def get_all_files(folder_path, dataset, universe="") -> (str, str):
|
||||
if universe != "":
|
||||
universe = f"_{universe}"
|
||||
yaml_path = str(Path(f"{folder_path}") / f"*{dataset}{universe}.yaml")
|
||||
req_path = str(Path(f"{folder_path}") / f"*.txt")
|
||||
yaml_file = glob.glob(yaml_path)
|
||||
req_file = glob.glob(req_path)
|
||||
@@ -224,6 +226,7 @@ class ModelRunner:
|
||||
times=1,
|
||||
models=None,
|
||||
dataset="Alpha360",
|
||||
universe="",
|
||||
exclude=False,
|
||||
qlib_uri: str = "git+https://github.com/microsoft/qlib#egg=pyqlib",
|
||||
exp_folder_name: str = "run_all_model_records",
|
||||
@@ -245,6 +248,9 @@ class ModelRunner:
|
||||
determines whether the model being used is excluded or included.
|
||||
dataset : str
|
||||
determines the dataset to be used for each model.
|
||||
universe : str
|
||||
the stock universe of the dataset.
|
||||
default "" indicates that
|
||||
qlib_uri : str
|
||||
the uri to install qlib with pip
|
||||
it could be url on the we or local path (NOTE: the local path must be a absolute path)
|
||||
@@ -259,6 +265,15 @@ class ModelRunner:
|
||||
-------
|
||||
Here are some use cases of the function in the bash:
|
||||
|
||||
The run_all_models will decide which config to run based no `models` `dataset` `universe`
|
||||
Example 1):
|
||||
|
||||
models="lightgbm", dataset="Alpha158", universe="" will result in running the following config
|
||||
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml
|
||||
|
||||
models="lightgbm", dataset="Alpha158", universe="csi500" will result in running the following config
|
||||
examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_csi500.yaml
|
||||
|
||||
.. code-block:: bash
|
||||
|
||||
# Case 1 - run all models multiple times
|
||||
@@ -279,6 +294,9 @@ class ModelRunner:
|
||||
# Case 6 - run other models except those are given as arguments for one time
|
||||
python run_all_model.py run --models=[mlp,tft,sfm] --exclude=True
|
||||
|
||||
# Case 7 - run lightgbm model on csi500.
|
||||
python run_all_model.py run 3 lightgbm Alpha158 csi500
|
||||
|
||||
"""
|
||||
self._init_qlib(exp_folder_name)
|
||||
|
||||
@@ -290,7 +308,7 @@ class ModelRunner:
|
||||
for fn in folders:
|
||||
# get all files
|
||||
sys.stderr.write("Retrieving files...\n")
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset)
|
||||
yaml_path, req_path = get_all_files(folders[fn], dataset, universe=universe)
|
||||
if yaml_path is None:
|
||||
sys.stderr.write(f"There is no {dataset}.yaml file in {folders[fn]}")
|
||||
continue
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
from pathlib import Path
|
||||
|
||||
__version__ = "0.8.5.99"
|
||||
__version__ = "0.8.6"
|
||||
__version__bak = __version__ # This version is backup for QlibConfig.reset_qlib_version
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
@@ -2,24 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import List, Tuple, Union, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..strategy.base import BaseStrategy
|
||||
from .executor import BaseExecutor
|
||||
from .decision import BaseTradeDecision
|
||||
from .position import Position
|
||||
from .exchange import Exchange
|
||||
from .backtest import backtest_loop
|
||||
from .backtest import collect_data_loop
|
||||
from .utils import CommonInfrastructure
|
||||
from .decision import Order
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
|
||||
from ..config import C
|
||||
from ..log import get_module_logger
|
||||
from ..utils import init_instance_by_config
|
||||
from .backtest import backtest_loop, collect_data_loop
|
||||
from .decision import Order
|
||||
from .exchange import Exchange
|
||||
from .position import Position
|
||||
from .utils import CommonInfrastructure
|
||||
|
||||
# make import more user-friendly by adding `from qlib.backtest import STH`
|
||||
|
||||
@@ -28,26 +33,34 @@ logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_exchange(
|
||||
exchange=None,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5.0,
|
||||
limit_threshold=None,
|
||||
exchange: Union[str, dict, object, Path] = None,
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
subscribe_fields: list = [],
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> Exchange:
|
||||
"""get_exchange
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
# exchange related arguments
|
||||
exchange: Exchange().
|
||||
exchange: Exchange(). It could be None or any types that are acceptable by `init_instance_by_config`.
|
||||
freq: str
|
||||
frequency of data.
|
||||
start_time: Union[pd.Timestamp, str]
|
||||
closed start time for backtest.
|
||||
end_time: Union[pd.Timestamp, str]
|
||||
closed end time for backtest.
|
||||
codes: list|str
|
||||
list stock_id list or a string of instruments (i.e. all, csi500, sse50)
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
@@ -57,8 +70,6 @@ def get_exchange(
|
||||
min_cost : float
|
||||
min transaction cost. It is an absolute amount of cost instead of a ratio of your order's deal amount.
|
||||
e.g. You must pay at least 5 yuan of commission regardless of your order's deal amount.
|
||||
trade_unit : int
|
||||
Included in kwargs. Please refer to the docs of `__init__` of `Exchange`
|
||||
deal_price: Union[str, Tuple[str], List[str]]
|
||||
The `deal_price` supports following two types of input
|
||||
- <deal_price> : str
|
||||
@@ -101,10 +112,14 @@ def get_exchange(
|
||||
|
||||
|
||||
def create_account_instance(
|
||||
start_time, end_time, benchmark: str, account: Union[float, int, dict], pos_type: str = "Position"
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
benchmark: str,
|
||||
account: Union[float, int, dict],
|
||||
pos_type: str = "Position",
|
||||
) -> Account:
|
||||
"""
|
||||
# TODO: is very strange pass benchmark_config in the account(maybe for report)
|
||||
# TODO: is very strange pass benchmark_config in the account (maybe for report)
|
||||
# There should be a post-step to process the report.
|
||||
|
||||
Parameters
|
||||
@@ -132,6 +147,8 @@ def create_account_instance(
|
||||
key "cash" means initial cash.
|
||||
key "stock1" means the information of first stock with amount and price(optional).
|
||||
...
|
||||
pos_type: str
|
||||
Postion type.
|
||||
"""
|
||||
if isinstance(account, (int, float)):
|
||||
pos_kwargs = {"init_cash": account}
|
||||
@@ -159,15 +176,15 @@ def create_account_instance(
|
||||
|
||||
|
||||
def get_strategy_executor(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy: BaseStrategy,
|
||||
executor: BaseExecutor,
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
) -> Tuple[BaseStrategy, BaseExecutor]:
|
||||
|
||||
# NOTE:
|
||||
# - for avoiding recursive import
|
||||
@@ -176,7 +193,11 @@ def get_strategy_executor(
|
||||
from .executor import BaseExecutor # pylint: disable=C0415
|
||||
|
||||
trade_account = create_account_instance(
|
||||
start_time=start_time, end_time=end_time, benchmark=benchmark, account=account, pos_type=pos_type
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
benchmark=benchmark,
|
||||
account=account,
|
||||
pos_type=pos_type,
|
||||
)
|
||||
|
||||
exchange_kwargs = copy.copy(exchange_kwargs)
|
||||
@@ -196,29 +217,31 @@ def get_strategy_executor(
|
||||
|
||||
|
||||
def backtest(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
):
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""initialize the strategy and executor, then backtest function for the interaction of the outermost strategy and
|
||||
executor in the nested decision execution
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
strategy : Union[str, dict, BaseStrategy]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more information.
|
||||
executor : Union[str, dict, BaseExecutor]
|
||||
strategy : Union[str, dict, object, Path]
|
||||
for initializing outermost portfolio strategy. Please refer to the docs of init_instance_by_config for more
|
||||
information.
|
||||
executor : Union[str, dict, object, Path]
|
||||
for initializing the outermost executor.
|
||||
benchmark: str
|
||||
the benchmark for reporting.
|
||||
@@ -257,16 +280,16 @@ def backtest(
|
||||
|
||||
|
||||
def collect_data(
|
||||
start_time,
|
||||
end_time,
|
||||
strategy,
|
||||
executor,
|
||||
benchmark="SH000300",
|
||||
account=1e9,
|
||||
exchange_kwargs={},
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
strategy: Union[str, dict, object, Path],
|
||||
executor: Union[str, dict, object, Path],
|
||||
benchmark: str = "SH000300",
|
||||
account: Union[float, int, Position] = 1e9,
|
||||
exchange_kwargs: dict = {},
|
||||
pos_type: str = "Position",
|
||||
return_value: dict = None,
|
||||
):
|
||||
) -> Generator[object, None, None]:
|
||||
"""initialize the strategy and executor, then collect the trade decision data for rl training
|
||||
|
||||
please refer to the docs of the backtest for the explanation of the parameters
|
||||
@@ -291,7 +314,7 @@ def collect_data(
|
||||
|
||||
def format_decisions(
|
||||
decisions: List[BaseTradeDecision],
|
||||
) -> Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]:
|
||||
) -> Optional[Tuple[str, List[Tuple[BaseTradeDecision, Union[Tuple, None]]]]]:
|
||||
"""
|
||||
format the decisions collected by `qlib.backtest.collect_data`
|
||||
The decisions will be organized into a tree-like structure.
|
||||
@@ -326,4 +349,4 @@ def format_decisions(
|
||||
return res
|
||||
|
||||
|
||||
__all__ = ["Order"]
|
||||
__all__ = ["Order", "backtest"]
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from typing import Dict, List, Tuple
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from .position import BasePosition
|
||||
from .report import PortfolioMetrics, Indicator
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .position import BasePosition
|
||||
from .report import Indicator, PortfolioMetrics
|
||||
|
||||
"""
|
||||
rtn & earning in the Account
|
||||
@@ -34,40 +37,42 @@ class AccumulatedInfo:
|
||||
AccumulatedInfo should be shared across different levels
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.rtn = 0 # accumulated return, do not consider cost
|
||||
self.cost = 0 # accumulated cost
|
||||
self.to = 0 # accumulated turnover
|
||||
def reset(self) -> None:
|
||||
self.rtn: float = 0.0 # accumulated return, do not consider cost
|
||||
self.cost: float = 0.0 # accumulated cost
|
||||
self.to: float = 0.0 # accumulated turnover
|
||||
|
||||
def add_return_value(self, value):
|
||||
def add_return_value(self, value: float) -> None:
|
||||
self.rtn += value
|
||||
|
||||
def add_cost(self, value):
|
||||
def add_cost(self, value: float) -> None:
|
||||
self.cost += value
|
||||
|
||||
def add_turnover(self, value):
|
||||
def add_turnover(self, value: float) -> None:
|
||||
self.to += value
|
||||
|
||||
@property
|
||||
def get_return(self):
|
||||
def get_return(self) -> float:
|
||||
return self.rtn
|
||||
|
||||
@property
|
||||
def get_cost(self):
|
||||
def get_cost(self) -> float:
|
||||
return self.cost
|
||||
|
||||
@property
|
||||
def get_turnover(self):
|
||||
def get_turnover(self) -> float:
|
||||
return self.to
|
||||
|
||||
|
||||
class Account:
|
||||
"""
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is shared cross all the Account object.
|
||||
The correctness of the metrics of Account in nested execution depends on the shallow copy of `trade_account` in
|
||||
qlib/backtest/executor.py:NestedExecutor
|
||||
Different level of executor has different Account object when calculating metrics. But the position object is
|
||||
shared cross all the Account object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -78,7 +83,7 @@ class Account:
|
||||
benchmark_config: dict = {},
|
||||
pos_type: str = "Position",
|
||||
port_metr_enabled: bool = True,
|
||||
):
|
||||
) -> None:
|
||||
"""the trade account of backtest.
|
||||
|
||||
Parameters
|
||||
@@ -102,7 +107,7 @@ class Account:
|
||||
self.benchmark_config = None # avoid no attribute error
|
||||
self.init_vars(init_cash, position_dict, freq, benchmark_config)
|
||||
|
||||
def init_vars(self, init_cash, position_dict, freq: str, benchmark_config: dict):
|
||||
def init_vars(self, init_cash: float, position_dict: dict, freq: str, benchmark_config: dict) -> None:
|
||||
# 1) the following variables are shared by multiple layers
|
||||
# - you will see a shallow copy instead of deepcopy in the NestedExecutor;
|
||||
self.init_cash = init_cash
|
||||
@@ -114,7 +119,7 @@ class Account:
|
||||
"position_dict": position_dict,
|
||||
},
|
||||
"module_path": "qlib.backtest.position",
|
||||
}
|
||||
},
|
||||
)
|
||||
self.accum_info = AccumulatedInfo()
|
||||
|
||||
@@ -123,13 +128,13 @@ class Account:
|
||||
self.hist_positions = {}
|
||||
self.reset(freq=freq, benchmark_config=benchmark_config)
|
||||
|
||||
def is_port_metr_enabled(self):
|
||||
def is_port_metr_enabled(self) -> bool:
|
||||
"""
|
||||
Is portfolio-based metrics enabled.
|
||||
"""
|
||||
return self._port_metr_enabled and not self.current_position.skip_update()
|
||||
|
||||
def reset_report(self, freq, benchmark_config):
|
||||
def reset_report(self, freq: str, benchmark_config: dict) -> None:
|
||||
# portfolio related metrics
|
||||
if self.is_port_metr_enabled():
|
||||
# NOTE:
|
||||
@@ -140,13 +145,13 @@ class Account:
|
||||
# fill stock value
|
||||
# The frequency of account may not align with the trading frequency.
|
||||
# This may result in obscure bugs when data quality is low.
|
||||
if isinstance(self.benchmark_config, dict) and self.benchmark_config.get("start_time") is not None:
|
||||
if isinstance(self.benchmark_config, dict) and "start_time" in self.benchmark_config:
|
||||
self.current_position.fill_stock_value(self.benchmark_config["start_time"], self.freq)
|
||||
|
||||
# trading related metrics(e.g. high-frequency trading)
|
||||
self.indicator = Indicator()
|
||||
|
||||
def reset(self, freq=None, benchmark_config=None, port_metr_enabled: bool = None):
|
||||
def reset(self, freq: str = None, benchmark_config: dict = None, port_metr_enabled: bool = None) -> None:
|
||||
"""reset freq and report of account
|
||||
|
||||
Parameters
|
||||
@@ -155,6 +160,7 @@ class Account:
|
||||
frequency of account & report, by default None
|
||||
benchmark_config : {}, optional
|
||||
benchmark config of report, by default None
|
||||
port_metr_enabled: bool
|
||||
"""
|
||||
if freq is not None:
|
||||
self.freq = freq
|
||||
@@ -165,13 +171,13 @@ class Account:
|
||||
|
||||
self.reset_report(self.freq, self.benchmark_config)
|
||||
|
||||
def get_hist_positions(self):
|
||||
def get_hist_positions(self) -> dict:
|
||||
return self.hist_positions
|
||||
|
||||
def get_cash(self):
|
||||
def get_cash(self) -> float:
|
||||
return self.current_position.get_cash()
|
||||
|
||||
def _update_state_from_order(self, order, trade_val, cost, trade_price):
|
||||
def _update_state_from_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.is_port_metr_enabled():
|
||||
# update turnover
|
||||
self.accum_info.add_turnover(trade_val)
|
||||
@@ -191,13 +197,14 @@ class Account:
|
||||
profit = self.current_position.get_stock_price(order.stock_id) * trade_amount - trade_val
|
||||
self.accum_info.add_return_value(profit) # note here do not consider cost
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
if self.current_position.skip_update():
|
||||
# TODO: supporting polymorphism for account
|
||||
# updating order for infinite position is meaningless
|
||||
return
|
||||
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first, then update current position
|
||||
# if stock is sold out, no stock price information in Position, then we should update account first,
|
||||
# then update current position
|
||||
# if stock is bought, there is no stock in current position, update current, then update account
|
||||
# The cost will be subtracted from the cash at last. So the trading logic can ignore the cost calculation
|
||||
if order.direction == Order.SELL:
|
||||
@@ -212,8 +219,15 @@ class Account:
|
||||
self.current_position.update_order(order, trade_val, cost, trade_price)
|
||||
self._update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_current_position(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock"""
|
||||
def update_current_position(
|
||||
self,
|
||||
trade_start_time: pd.Timestamp,
|
||||
trade_end_time: pd.Timestamp,
|
||||
trade_exchange: Exchange,
|
||||
) -> None:
|
||||
"""
|
||||
Update current to make rtn consistent with earning at the end of bar, and update holding bar count of stock
|
||||
"""
|
||||
# update price for stock in the position and the profit from changed_price
|
||||
# NOTE: updating position does not only serve portfolio metrics, it also serve the strategy
|
||||
if not self.current_position.skip_update():
|
||||
@@ -228,7 +242,7 @@ class Account:
|
||||
# NOTE: updating bar_count does not only serve portfolio metrics, it also serve the strategy
|
||||
self.current_position.add_count_all(bar=self.freq)
|
||||
|
||||
def update_portfolio_metrics(self, trade_start_time, trade_end_time):
|
||||
def update_portfolio_metrics(self, trade_start_time: pd.Timestamp, trade_end_time: pd.Timestamp) -> None:
|
||||
"""update portfolio_metrics"""
|
||||
# calculate earning
|
||||
# account_value - last_account_value
|
||||
@@ -243,14 +257,16 @@ class Account:
|
||||
last_account_value = self.portfolio_metrics.get_latest_account_value()
|
||||
last_total_cost = self.portfolio_metrics.get_latest_total_cost()
|
||||
last_total_turnover = self.portfolio_metrics.get_latest_total_turnover()
|
||||
|
||||
# get now_account_value, now_stock_value, now_earning, now_cost, now_turnover
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
now_stock_value = self.current_position.calculate_stock_value()
|
||||
now_earning = now_account_value - last_account_value
|
||||
now_cost = self.accum_info.get_cost - last_total_cost
|
||||
now_turnover = self.accum_info.get_turnover - last_total_turnover
|
||||
|
||||
# update portfolio_metrics for today
|
||||
# judge whether the the trading is begin.
|
||||
# judge whether the trading is begin.
|
||||
# and don't add init account state into portfolio_metrics, due to we don't have excess return in those days.
|
||||
self.portfolio_metrics.update_portfolio_metrics_record(
|
||||
trade_start_time=trade_start_time,
|
||||
@@ -267,7 +283,7 @@ class Account:
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
|
||||
def update_hist_positions(self, trade_start_time):
|
||||
def update_hist_positions(self, trade_start_time: pd.Timestamp) -> None:
|
||||
"""update history position"""
|
||||
now_account_value = self.current_position.calculate_value()
|
||||
# set now_account_value to position
|
||||
@@ -287,7 +303,7 @@ class Account:
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update trade indicators and order indicators in each bar end"""
|
||||
# TODO: will skip empty decisions make it faster? `outer_trade_decision.empty():`
|
||||
|
||||
@@ -323,7 +339,7 @@ class Account:
|
||||
inner_order_indicators: List[Dict[str, pd.Series]] = None,
|
||||
decision_list: List[Tuple[BaseTradeDecision, pd.Timestamp, pd.Timestamp]] = None,
|
||||
indicator_config: dict = {},
|
||||
):
|
||||
) -> None:
|
||||
"""update account at each trading bar step
|
||||
|
||||
Parameters
|
||||
@@ -338,6 +354,8 @@ class Account:
|
||||
whether the trading executor is atomic, which means there is no higher-frequency trading executor inside it
|
||||
- if atomic is True, calculate the indicators with trade_info
|
||||
- else, aggregate indicators with inner indicators
|
||||
outer_trade_decision: BaseTradeDecision
|
||||
external trade decision
|
||||
trade_info : List[(Order, float, float, float)], optional
|
||||
trading information, by default None
|
||||
- necessary if atomic is True
|
||||
@@ -377,7 +395,7 @@ class Account:
|
||||
indicator_config=indicator_config,
|
||||
)
|
||||
|
||||
def get_portfolio_metrics(self):
|
||||
def get_portfolio_metrics(self) -> Tuple[pd.DataFrame, dict]:
|
||||
"""get the history portfolio_metrics and positions instance"""
|
||||
if self.is_port_metr_enabled():
|
||||
_portfolio_metrics = self.portfolio_metrics.generate_portfolio_metrics_dataframe()
|
||||
|
||||
@@ -2,17 +2,29 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Generator, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
from typing import TYPE_CHECKING
|
||||
from qlib.backtest.report import Indicator, PortfolioMetrics
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.executor import BaseExecutor
|
||||
from ..utils.time import Freq
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from ..utils.time import Freq
|
||||
|
||||
def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor):
|
||||
|
||||
def backtest_loop(
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
) -> Tuple[PortfolioMetrics, Indicator]:
|
||||
"""backtest function for the interaction of the outermost strategy and executor in the nested decision execution
|
||||
|
||||
please refer to the docs of `collect_data_loop`
|
||||
@@ -31,19 +43,23 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec
|
||||
|
||||
|
||||
def collect_data_loop(
|
||||
start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor, return_value: dict = None
|
||||
):
|
||||
start_time: Union[pd.Timestamp, str],
|
||||
end_time: Union[pd.Timestamp, str],
|
||||
trade_strategy: BaseStrategy,
|
||||
trade_executor: BaseExecutor,
|
||||
return_value: dict = None,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], None]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : pd.Timestamp|str
|
||||
start_time : Union[pd.Timestamp, str]
|
||||
closed start time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
end_time : pd.Timestamp|str
|
||||
end_time : Union[pd.Timestamp, str]
|
||||
closed end time for backtest
|
||||
**NOTE**: This will be applied to the outmost executor's calendar.
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
E.g. Executor[day](Executor[1min]), setting `end_time == 20XX0301` will include all the minutes on 20XX0301
|
||||
trade_strategy : BaseStrategy
|
||||
the outermost portfolio strategy
|
||||
trade_executor : BaseExecutor
|
||||
|
||||
@@ -2,23 +2,26 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
from enum import IntEnum
|
||||
from qlib.data.data import Cal
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
from typing import ClassVar, Optional, Union, List, Tuple
|
||||
from abc import abstractmethod
|
||||
from enum import IntEnum
|
||||
|
||||
# try to fix circular imports when enabling type hints
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, ClassVar, List, Optional, Tuple, Union
|
||||
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
from qlib.data.data import Cal
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.utils.time import concat_date_time, epsilon_change
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.strategy.base import BaseStrategy
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.utils import TradeCalendarManager
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class OrderDir(IntEnum):
|
||||
@@ -46,7 +49,7 @@ class Order:
|
||||
# - they are set by users and is time-invariant.
|
||||
stock_id: str
|
||||
amount: float # `amount` is a non-negative and adjusted value
|
||||
direction: int
|
||||
direction: OrderDir
|
||||
|
||||
# 2) time variant values:
|
||||
# - Users may want to set these values when using lower level APIs
|
||||
@@ -61,7 +64,7 @@ class Order:
|
||||
# What the value should be about in all kinds of cases
|
||||
# - not tradable: the deal_amount == 0 , factor is None
|
||||
# - the stock is suspended and the entire order fails. No cost for this order
|
||||
# - dealed or partially dealed: deal_amount >= 0 and factor is not None
|
||||
# - dealt or partially dealt: deal_amount >= 0 and factor is not None
|
||||
deal_amount: Optional[float] = None # `deal_amount` is a non-negative value
|
||||
factor: Optional[float] = None
|
||||
|
||||
@@ -74,10 +77,10 @@ class Order:
|
||||
SELL: ClassVar[OrderDir] = OrderDir.SELL
|
||||
BUY: ClassVar[OrderDir] = OrderDir.BUY
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
if self.direction not in {Order.SELL, Order.BUY}:
|
||||
raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy")
|
||||
self.deal_amount = 0
|
||||
self.deal_amount = 0.0
|
||||
self.factor = None
|
||||
|
||||
@property
|
||||
@@ -99,7 +102,7 @@ class Order:
|
||||
return self.deal_amount * self.sign
|
||||
|
||||
@property
|
||||
def sign(self) -> float:
|
||||
def sign(self) -> int:
|
||||
"""
|
||||
return the sign of trading
|
||||
- `+1` indicates buying
|
||||
@@ -112,15 +115,12 @@ class Order:
|
||||
if isinstance(direction, OrderDir):
|
||||
return direction
|
||||
elif isinstance(direction, (int, float, np.integer, np.floating)):
|
||||
if direction > 0:
|
||||
return Order.BUY
|
||||
else:
|
||||
return Order.SELL
|
||||
return Order.BUY if direction > 0 else Order.SELL
|
||||
elif isinstance(direction, str):
|
||||
dl = direction.lower()
|
||||
if dl.strip() == "sell":
|
||||
dl = direction.lower().strip()
|
||||
if dl == "sell":
|
||||
return OrderDir.SELL
|
||||
elif dl.strip() == "buy":
|
||||
elif dl == "buy":
|
||||
return OrderDir.BUY
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -138,14 +138,14 @@ class OrderHelper:
|
||||
Motivation
|
||||
- Make generating order easier
|
||||
- User may have no knowledge about the adjust-factor information about the system.
|
||||
- It involves to much interaction with the exchange when generating orders.
|
||||
- It involves too much interaction with the exchange when generating orders.
|
||||
"""
|
||||
|
||||
def __init__(self, exchange: Exchange):
|
||||
def __init__(self, exchange: Exchange) -> None:
|
||||
self.exchange = exchange
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
self,
|
||||
code: str,
|
||||
amount: float,
|
||||
direction: OrderDir,
|
||||
@@ -175,21 +175,18 @@ class OrderHelper:
|
||||
Order:
|
||||
The created order
|
||||
"""
|
||||
if start_time is not None:
|
||||
start_time = pd.Timestamp(start_time)
|
||||
if end_time is not None:
|
||||
end_time = pd.Timestamp(end_time)
|
||||
# NOTE: factor is a value belongs to the results section. User don't have to care about it when creating orders
|
||||
return Order(
|
||||
stock_id=code,
|
||||
amount=amount,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
start_time=start_time if start_time is not None else pd.Timestamp(start_time),
|
||||
end_time=end_time if end_time is not None else pd.Timestamp(end_time),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
|
||||
class TradeRange:
|
||||
@abstractmethod
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
"""
|
||||
This method will be call with following way
|
||||
@@ -216,6 +213,7 @@ class TradeRange:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `__call__` method")
|
||||
|
||||
@abstractmethod
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Parameters
|
||||
@@ -234,23 +232,26 @@ class TradeRange:
|
||||
|
||||
|
||||
class IdxTradeRange(TradeRange):
|
||||
def __init__(self, start_idx: int, end_idx: int):
|
||||
def __init__(self, start_idx: int, end_idx: int) -> None:
|
||||
self._start_idx = start_idx
|
||||
self._end_idx = end_idx
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
return self._start_idx, self._end_idx
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TradeRangeByTime(TradeRange):
|
||||
"""This is a helper function for make decisions"""
|
||||
|
||||
def __init__(self, start_time: str, end_time: str):
|
||||
def __init__(self, start_time: str, end_time: str) -> None:
|
||||
"""
|
||||
This is a callable class.
|
||||
|
||||
**NOTE**:
|
||||
- It is designed for minute-bar for intraday trading!!!!!
|
||||
- It is designed for minute-bar for intra-day trading!!!!!
|
||||
- Both start_time and end_time are **closed** in the range
|
||||
|
||||
Parameters
|
||||
@@ -264,26 +265,25 @@ class TradeRangeByTime(TradeRange):
|
||||
self.end_time = pd.Timestamp(end_time).time()
|
||||
assert self.start_time < self.end_time
|
||||
|
||||
def __call__(self, trade_calendar: TradeCalendarManager = None) -> Tuple[int, int]:
|
||||
def __call__(self, trade_calendar: TradeCalendarManager) -> Tuple[int, int]:
|
||||
if trade_calendar is None:
|
||||
raise NotImplementedError("trade_calendar is necessary for getting TradeRangeByTime.")
|
||||
start = trade_calendar.start_time
|
||||
val_start, val_end = concat_date_time(start.date(), self.start_time), concat_date_time(
|
||||
start.date(), self.end_time
|
||||
)
|
||||
|
||||
start_date = trade_calendar.start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
return trade_calendar.get_range_idx(val_start, val_end)
|
||||
|
||||
def clip_time_range(self, start_time: pd.Timestamp, end_time: pd.Timestamp) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
start_date = start_time.date()
|
||||
val_start, val_end = concat_date_time(start_date, self.start_time), concat_date_time(start_date, self.end_time)
|
||||
# NOTE: `end_date` should not be used. Because the `end_date` is for slicing. It may be in the next day
|
||||
# Assumption: start_time and end_time is for intraday trading. So it is OK for only using start_date
|
||||
# Assumption: start_time and end_time is for intra-day trading. So it is OK for only using start_date
|
||||
return max(val_start, start_time), min(val_end, end_time)
|
||||
|
||||
|
||||
class BaseTradeDecision:
|
||||
"""
|
||||
Trade decisions ara made by strategy and executed by exeuter
|
||||
Trade decisions ara made by strategy and executed by executor
|
||||
|
||||
Motivation:
|
||||
Here are several typical scenarios for `BaseTradeDecision`
|
||||
@@ -297,7 +297,7 @@ class BaseTradeDecision:
|
||||
2. Same as `case 1.3`
|
||||
"""
|
||||
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None):
|
||||
def __init__(self, strategy: BaseStrategy, trade_range: Union[Tuple[int, int], TradeRange] = None) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -339,7 +339,7 @@ class BaseTradeDecision:
|
||||
"""
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]:
|
||||
def update(self, trade_calendar: TradeCalendarManager) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
Be called at the **start** of each step.
|
||||
|
||||
@@ -354,10 +354,8 @@ class BaseTradeDecision:
|
||||
|
||||
Returns
|
||||
-------
|
||||
None:
|
||||
No update, use previous decision(or unavailable)
|
||||
BaseTradeDecision:
|
||||
New update, use new decision
|
||||
New update, use new decision. If no updates, return None (use previous decision (or unavailable))
|
||||
"""
|
||||
# purpose 1)
|
||||
self.total_step = trade_calendar.get_trade_len()
|
||||
@@ -412,12 +410,12 @@ class BaseTradeDecision:
|
||||
"""
|
||||
try:
|
||||
_start_idx, _end_idx = self._get_range_limit(**kwargs)
|
||||
except NotImplementedError:
|
||||
except NotImplementedError as e:
|
||||
if "default_value" in kwargs:
|
||||
return kwargs["default_value"]
|
||||
else:
|
||||
# Default to get full index
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from NotImplementedError
|
||||
raise NotImplementedError(f"The decision didn't provide an index range") from e
|
||||
|
||||
# clip index
|
||||
if getattr(self, "total_step", None) is not None:
|
||||
@@ -426,7 +424,7 @@ class BaseTradeDecision:
|
||||
if _start_idx < 0 or _end_idx >= self.total_step:
|
||||
logger = get_module_logger("decision")
|
||||
logger.warning(
|
||||
f"[{_start_idx},{_end_idx}] go beyoud the total_step({self.total_step}), it will be clipped"
|
||||
f"[{_start_idx},{_end_idx}] go beyond the total_step({self.total_step}), it will be clipped.",
|
||||
)
|
||||
_start_idx, _end_idx = max(0, _start_idx), min(self.total_step - 1, _end_idx)
|
||||
return _start_idx, _end_idx
|
||||
@@ -444,7 +442,7 @@ class BaseTradeDecision:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
raise_error: bool
|
||||
@@ -497,11 +495,10 @@ class BaseTradeDecision:
|
||||
return True
|
||||
return True
|
||||
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision):
|
||||
def mod_inner_decision(self, inner_trade_decision: BaseTradeDecision) -> None:
|
||||
"""
|
||||
|
||||
This method will be called on the inner_trade_decision after it is generated.
|
||||
`inner_trade_decision` will be changed **inplaced**.
|
||||
`inner_trade_decision` will be changed **inplace**.
|
||||
|
||||
Motivation of the `mod_inner_decision`
|
||||
- Leave a hook for outer decision to affect the decision generated by the inner strategy
|
||||
@@ -520,6 +517,9 @@ class BaseTradeDecision:
|
||||
|
||||
|
||||
class EmptyTradeDecision(BaseTradeDecision):
|
||||
def get_decision(self) -> List[object]:
|
||||
return []
|
||||
|
||||
def empty(self) -> bool:
|
||||
return True
|
||||
|
||||
@@ -544,4 +544,9 @@ class TradeDecisionWO(BaseTradeDecision):
|
||||
return self.order_list
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; strategy: {self.strategy}; trade_range: {self.trade_range}; order_list[{len(self.order_list)}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"strategy: {self.strategy}; "
|
||||
f"trade_range: {self.trade_range}; "
|
||||
f"order_list[{len(self.order_list)}]"
|
||||
)
|
||||
|
||||
@@ -1,21 +1,25 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import List, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
||||
|
||||
from ..utils.index_data import IndexData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .account import Account
|
||||
|
||||
from qlib.backtest.position import BasePosition, Position
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import D
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from ..config import C
|
||||
from ..constant import REG_CN
|
||||
from ..data.data import D
|
||||
from ..log import get_module_logger
|
||||
from .decision import Order, OrderDir, OrderHelper
|
||||
from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
@@ -24,22 +28,22 @@ from .high_performance_ds import BaseQuote, NumpyQuote
|
||||
class Exchange:
|
||||
def __init__(
|
||||
self,
|
||||
freq="day",
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes="all",
|
||||
freq: str = "day",
|
||||
start_time: Union[pd.Timestamp, str] = None,
|
||||
end_time: Union[pd.Timestamp, str] = None,
|
||||
codes: Union[list, str] = "all",
|
||||
deal_price: Union[str, Tuple[str], List[str]] = None,
|
||||
subscribe_fields=[],
|
||||
subscribe_fields: list = [],
|
||||
limit_threshold: Union[Tuple[str, str], float, None] = None,
|
||||
volume_threshold=None,
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
min_cost=5,
|
||||
impact_cost=0.0,
|
||||
extra_quote=None,
|
||||
quote_cls=NumpyQuote,
|
||||
volume_threshold: Union[tuple, dict] = None,
|
||||
open_cost: float = 0.0015,
|
||||
close_cost: float = 0.0025,
|
||||
min_cost: float = 5.0,
|
||||
impact_cost: float = 0.0,
|
||||
extra_quote: pd.DataFrame = None,
|
||||
quote_cls: Type[BaseQuote] = NumpyQuote,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""__init__
|
||||
:param freq: frequency of data
|
||||
:param start_time: closed start time for backtest
|
||||
@@ -72,11 +76,12 @@ class Exchange:
|
||||
]
|
||||
1) ("cum" or "current", limit_str) denotes a single volume limit.
|
||||
- limit_str is qlib data expression which is allowed to define your own Operator.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for high frequency,
|
||||
such as DayCumsum. !!!NOTE: if you want you use the custom operator, you need to
|
||||
register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market volume.
|
||||
So when it is used as a volume limit, it is necessary to subtract the dealt amount.
|
||||
Please refer to qlib/contrib/ops/high_freq.py, here are any custom operator for
|
||||
high frequency, such as DayCumsum. !!!NOTE: if you want you use the custom
|
||||
operator, you need to register it in qlib_init.
|
||||
- "cum" means that this is a cumulative value over time, such as cumulative market
|
||||
volume. So when it is used as a volume limit, it is necessary to subtract the dealt
|
||||
amount.
|
||||
- "current" means that this is a real-time value and will not accumulate over time,
|
||||
so it can be directly used as a capacity limit.
|
||||
e.g. ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"), ("current", "$bidV1")
|
||||
@@ -84,7 +89,7 @@ class Exchange:
|
||||
"buy" means the volume limits of buying. "sell" means the volume limits of selling.
|
||||
Different volume limits will be aggregated with min(). If volume_threshold is only
|
||||
("cum" or "current", limit_str) instead of a dict, the volume limits are for
|
||||
both by deault. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
both by default. In other words, it is same as {"all": ("cum" or "current", limit_str)}.
|
||||
3) e.g. "volume_threshold": {
|
||||
"all": ("cum", "0.2 * DayCumsum($volume, '9:45', '14:45')"),
|
||||
"buy": ("current", "$askV1"),
|
||||
@@ -104,13 +109,14 @@ class Exchange:
|
||||
Necessary fields:
|
||||
$close is for calculating the total value at end of each day.
|
||||
Optional fields:
|
||||
$volume is only necessary when we limit the trade amount or calculate PA(vwap) indicator
|
||||
$volume is only necessary when we limit the trade amount or calculate
|
||||
PA(vwap) indicator
|
||||
$vwap is only necessary when we use the $vwap price as the deal price
|
||||
$factor is for rounding to the trading unit
|
||||
limit_sell will be set to False by default(False indicates we can sell this
|
||||
target on this day).
|
||||
limit_buy will be set to False by default(False indicates we can buy this
|
||||
target on this day).
|
||||
limit_sell will be set to False by default (False indicates we can sell
|
||||
this target on this day).
|
||||
limit_buy will be set to False by default (False indicates we can buy
|
||||
this target on this day).
|
||||
index: MultipleIndex(instrument, pd.Datetime)
|
||||
"""
|
||||
self.freq = freq
|
||||
@@ -163,7 +169,7 @@ class Exchange:
|
||||
if self.limit_type == self.LT_TP_EXP:
|
||||
for exp in limit_threshold:
|
||||
necessary_fields.add(exp)
|
||||
all_fields = necessary_fields | vol_lt_fields
|
||||
all_fields = necessary_fields | set(vol_lt_fields)
|
||||
all_fields = list(all_fields | set(subscribe_fields))
|
||||
|
||||
self.all_fields = all_fields
|
||||
@@ -182,17 +188,22 @@ class Exchange:
|
||||
self.quote_cls = quote_cls
|
||||
self.quote: BaseQuote = self.quote_cls(self.quote_df, freq)
|
||||
|
||||
def get_quote_from_qlib(self):
|
||||
def get_quote_from_qlib(self) -> None:
|
||||
# get stock data from qlib
|
||||
if len(self.codes) == 0:
|
||||
self.codes = D.instruments()
|
||||
self.quote_df = D.features(
|
||||
self.codes, self.all_fields, self.start_time, self.end_time, freq=self.freq, disk_cache=True
|
||||
self.codes,
|
||||
self.all_fields,
|
||||
self.start_time,
|
||||
self.end_time,
|
||||
freq=self.freq,
|
||||
disk_cache=True,
|
||||
).dropna(subset=["$close"])
|
||||
self.quote_df.columns = self.all_fields
|
||||
|
||||
# check buy_price data and sell_price data
|
||||
for attr in "buy_price", "sell_price":
|
||||
for attr in ("buy_price", "sell_price"):
|
||||
pstr = getattr(self, attr) # price string
|
||||
if self.quote_df[pstr].isna().any():
|
||||
self.logger.warning("{} field data contains nan.".format(pstr))
|
||||
@@ -238,7 +249,7 @@ class Exchange:
|
||||
LT_FLT = "float" # float
|
||||
LT_NONE = "none" # none
|
||||
|
||||
def _get_limit_type(self, limit_threshold):
|
||||
def _get_limit_type(self, limit_threshold: Union[Tuple, float, None]) -> str:
|
||||
"""get limit type"""
|
||||
if isinstance(limit_threshold, Tuple):
|
||||
return self.LT_TP_EXP
|
||||
@@ -249,7 +260,7 @@ class Exchange:
|
||||
else:
|
||||
raise NotImplementedError(f"This type of `limit_threshold` is not supported")
|
||||
|
||||
def _update_limit(self, limit_threshold):
|
||||
def _update_limit(self, limit_threshold: Union[Tuple, float, None]) -> None:
|
||||
# check limit_threshold
|
||||
limit_type = self._get_limit_type(limit_threshold)
|
||||
if limit_type == self.LT_NONE:
|
||||
@@ -263,9 +274,10 @@ class Exchange:
|
||||
self.quote_df["limit_buy"] = self.quote_df["$change"].ge(limit_threshold)
|
||||
self.quote_df["limit_sell"] = self.quote_df["$change"].le(-limit_threshold) # pylint: disable=E1130
|
||||
|
||||
def _get_vol_limit(self, volume_threshold):
|
||||
@staticmethod
|
||||
def _get_vol_limit(volume_threshold: Union[tuple, dict]) -> Tuple[Optional[list], Optional[list], set]:
|
||||
"""
|
||||
preproccess the volume limit.
|
||||
preprocess the volume limit.
|
||||
get the fields need to get from qlib.
|
||||
get the volume limit list of buying and selling which is composed of all limits.
|
||||
Parameters
|
||||
@@ -295,8 +307,7 @@ class Exchange:
|
||||
volume_threshold = {"all": volume_threshold}
|
||||
|
||||
assert isinstance(volume_threshold, dict)
|
||||
for key in volume_threshold:
|
||||
vol_limit = volume_threshold[key]
|
||||
for key, vol_limit in volume_threshold.items():
|
||||
assert isinstance(vol_limit, tuple)
|
||||
fields.add(vol_limit[1])
|
||||
|
||||
@@ -307,10 +318,19 @@ class Exchange:
|
||||
|
||||
return buy_vol_limit, sell_vol_limit, fields
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
def check_stock_limit(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
stock_id : str
|
||||
start_time: pd.Timestamp
|
||||
end_time: pd.Timestamp
|
||||
direction : int, optional
|
||||
trade direction, by default None
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
@@ -328,39 +348,42 @@ class Exchange:
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
def check_stock_suspended(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> bool:
|
||||
# is suspended
|
||||
if stock_id in self.quote.get_all_stock():
|
||||
return self.quote.get_data(stock_id, start_time, end_time, "$close") is None
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
|
||||
def is_stock_tradable(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: int = None,
|
||||
) -> bool:
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
|
||||
stock_id, start_time, end_time, direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return not (
|
||||
self.check_stock_suspended(stock_id, start_time, end_time)
|
||||
or self.check_stock_limit(stock_id, start_time, end_time, direction)
|
||||
)
|
||||
|
||||
def check_order(self, order):
|
||||
def check_order(self, order: Order) -> bool:
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time, order.direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
return self.is_stock_tradable(order.stock_id, order.start_time, order.end_time, order.direction)
|
||||
|
||||
def deal_order(
|
||||
self,
|
||||
order,
|
||||
order: Order,
|
||||
trade_account: Account = None,
|
||||
position: BasePosition = None,
|
||||
dealt_order_amount: defaultdict = defaultdict(float),
|
||||
):
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Deal order when the actual transaction
|
||||
the results section in `Order` will be changed.
|
||||
@@ -371,9 +394,9 @@ class Exchange:
|
||||
:return: trade_val, trade_cost, trade_price
|
||||
"""
|
||||
# check order first.
|
||||
if self.check_order(order) is False:
|
||||
if not self.check_order(order):
|
||||
order.deal_amount = 0.0
|
||||
# using np.nan instead of None to make it more convenient to should the value in format string
|
||||
# using np.nan instead of None to make it more convenient to show the value in format string
|
||||
self.logger.debug(f"Order failed due to trading limitation: {order}")
|
||||
return 0.0, 0.0, np.nan
|
||||
|
||||
@@ -382,7 +405,9 @@ class Exchange:
|
||||
|
||||
# NOTE: order will be changed in this function
|
||||
trade_price, trade_val, trade_cost = self._calc_trade_info_by_order(
|
||||
order, trade_account.current_position if trade_account else position, dealt_order_amount
|
||||
order,
|
||||
trade_account.current_position if trade_account else position,
|
||||
dealt_order_amount,
|
||||
)
|
||||
if trade_val > 1e-5:
|
||||
# If the order can only be deal 0 value. Nothing to be updated
|
||||
@@ -396,23 +421,49 @@ class Exchange:
|
||||
|
||||
return trade_val, trade_cost, trade_price
|
||||
|
||||
def get_quote_info(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method)
|
||||
def get_quote_info(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, method=method) # TODO: missing `field`?
|
||||
|
||||
def get_close(self, stock_id, start_time, end_time, method="ts_data_last"):
|
||||
def get_close(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "ts_data_last",
|
||||
) -> Union[None, int, float, bool, IndexData]:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$close", method=method)
|
||||
|
||||
def get_volume(self, stock_id, start_time, end_time, method="sum"):
|
||||
def get_volume(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
method: str = "sum",
|
||||
) -> float:
|
||||
"""get the total deal volume of stock with `stock_id` between the time interval [start_time, end_time)"""
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$volume", method=method)
|
||||
|
||||
def get_deal_price(self, stock_id, start_time, end_time, direction: OrderDir, method="ts_data_last"):
|
||||
def get_deal_price(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir,
|
||||
method: str = "ts_data_last",
|
||||
) -> float:
|
||||
if direction == OrderDir.SELL:
|
||||
pstr = self.sell_price
|
||||
elif direction == OrderDir.BUY:
|
||||
pstr = self.buy_price
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
deal_price = self.quote.get_data(stock_id, start_time, end_time, field=pstr, method=method)
|
||||
if method is not None and (deal_price is None or np.isnan(deal_price) or deal_price <= 1e-08):
|
||||
self.logger.warning(f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {pstr}): {deal_price}!!!")
|
||||
@@ -420,11 +471,16 @@ class Exchange:
|
||||
deal_price = self.get_close(stock_id, start_time, end_time, method)
|
||||
return deal_price
|
||||
|
||||
def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]:
|
||||
def get_factor(
|
||||
self,
|
||||
stock_id: str,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
Returns
|
||||
-------
|
||||
Union[float, None]:
|
||||
Optional[float]:
|
||||
`None`: if the stock is suspended `None` may be returned
|
||||
`float`: return factor if the factor exists
|
||||
"""
|
||||
@@ -434,11 +490,16 @@ class Exchange:
|
||||
return self.quote.get_data(stock_id, start_time, end_time, field="$factor", method="ts_data_last")
|
||||
|
||||
def generate_amount_position_from_weight_position(
|
||||
self, weight_position, cash, start_time, end_time, direction=OrderDir.BUY
|
||||
):
|
||||
self,
|
||||
weight_position: dict,
|
||||
cash: float,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
direction: OrderDir = OrderDir.BUY,
|
||||
) -> dict:
|
||||
"""
|
||||
The generate the target position according to the weight and the cash.
|
||||
NOTE: All the cash will assigned to the tadable stock.
|
||||
NOTE: All the cash will assigned to the tradable stock.
|
||||
Parameter:
|
||||
weight_position : dict {stock_id : weight}; allocate cash by weight_position
|
||||
among then, weight must be in this range: 0 < weight < 1
|
||||
@@ -451,15 +512,14 @@ class Exchange:
|
||||
|
||||
# calculate the total weight of tradable value
|
||||
tradable_weight = 0.0
|
||||
for stock_id in weight_position:
|
||||
for stock_id, wp in weight_position.items():
|
||||
if self.is_stock_tradable(stock_id=stock_id, start_time=start_time, end_time=end_time):
|
||||
# weight_position must be greater than 0 and less than 1
|
||||
if weight_position[stock_id] < 0 or weight_position[stock_id] > 1:
|
||||
if wp < 0 or wp > 1:
|
||||
raise ValueError(
|
||||
"weight_position is {}, "
|
||||
"weight_position is not in the range of (0, 1).".format(weight_position[stock_id])
|
||||
"weight_position is {}, " "weight_position is not in the range of (0, 1).".format(wp),
|
||||
)
|
||||
tradable_weight += weight_position[stock_id]
|
||||
tradable_weight += wp
|
||||
|
||||
if tradable_weight - 1.0 >= 1e-5:
|
||||
raise ValueError("tradable_weight is {}, can not greater than 1.".format(tradable_weight))
|
||||
@@ -467,19 +527,24 @@ class Exchange:
|
||||
amount_dict = {}
|
||||
for stock_id in weight_position:
|
||||
if weight_position[stock_id] > 0.0 and self.is_stock_tradable(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
):
|
||||
amount_dict[stock_id] = (
|
||||
cash
|
||||
* weight_position[stock_id]
|
||||
/ tradable_weight
|
||||
// self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
)
|
||||
return amount_dict
|
||||
|
||||
def get_real_deal_amount(self, current_amount, target_amount, factor):
|
||||
def get_real_deal_amount(self, current_amount: float, target_amount: float, factor: float) -> float:
|
||||
"""
|
||||
Calculate the real adjust deal amount when considering the trading unit
|
||||
:param current_amount:
|
||||
@@ -501,7 +566,13 @@ class Exchange:
|
||||
deal_amount = self.round_amount_by_trade_unit(deal_amount, factor)
|
||||
return -deal_amount
|
||||
|
||||
def generate_order_for_target_amount_position(self, target_position, current_position, start_time, end_time):
|
||||
def generate_order_for_target_amount_position(
|
||||
self,
|
||||
target_position: dict,
|
||||
current_position: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
) -> list:
|
||||
"""
|
||||
Note: some future information is used in this function
|
||||
Parameter:
|
||||
@@ -517,7 +588,8 @@ class Exchange:
|
||||
# three parts: kept stock_id, dropped stock_id, new stock_id
|
||||
# handle kept stock_id
|
||||
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest results of the same parameter are different;
|
||||
# because the order of the set is not fixed, the trading order of the stock is different, so that the backtest
|
||||
# results of the same parameter are different;
|
||||
# so here we sort stock_id, and then randomly shuffle the order of stock_id
|
||||
# because the same random seed is used, the final stock_id order is fixed
|
||||
sorted_ids = sorted(set(list(current_position.keys()) + list(target_position.keys())))
|
||||
@@ -546,7 +618,7 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
else:
|
||||
# sell stock
|
||||
@@ -558,14 +630,19 @@ class Exchange:
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
factor=factor,
|
||||
)
|
||||
),
|
||||
)
|
||||
# return order_list : buy + sell
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
def calculate_amount_position_value(
|
||||
self, amount_dict, start_time, end_time, only_tradable=False, direction=OrderDir.SELL
|
||||
):
|
||||
self,
|
||||
amount_dict: dict,
|
||||
start_time: pd.Timestamp,
|
||||
end_time: pd.Timestamp,
|
||||
only_tradable: bool = False,
|
||||
direction: OrderDir = OrderDir.SELL,
|
||||
) -> float:
|
||||
"""Parameter
|
||||
position : Position()
|
||||
amount_dict : {stock_id : amount}
|
||||
@@ -576,21 +653,28 @@ class Exchange:
|
||||
"""
|
||||
value = 0
|
||||
for stock_id in amount_dict:
|
||||
if (
|
||||
only_tradable is True
|
||||
and self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
and self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time) is False
|
||||
or only_tradable is False
|
||||
if not only_tradable or (
|
||||
not self.check_stock_suspended(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
and not self.check_stock_limit(stock_id=stock_id, start_time=start_time, end_time=end_time)
|
||||
):
|
||||
value += (
|
||||
self.get_deal_price(
|
||||
stock_id=stock_id, start_time=start_time, end_time=end_time, direction=direction
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
direction=direction,
|
||||
)
|
||||
* amount_dict[stock_id]
|
||||
)
|
||||
return value
|
||||
|
||||
def _get_factor_or_raise_error(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def _get_factor_or_raise_error(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> float:
|
||||
"""Please refer to the docs of get_amount_of_trade_unit"""
|
||||
if factor is None:
|
||||
if stock_id is not None and start_time is not None and end_time is not None:
|
||||
@@ -599,7 +683,13 @@ class Exchange:
|
||||
raise ValueError(f"`factor` and (`stock_id`, `start_time`, `end_time`) can't both be None")
|
||||
return factor
|
||||
|
||||
def get_amount_of_trade_unit(self, factor: float = None, stock_id: str = None, start_time=None, end_time=None):
|
||||
def get_amount_of_trade_unit(
|
||||
self,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time: pd.Timestamp = None,
|
||||
end_time: pd.Timestamp = None,
|
||||
) -> Optional[float]:
|
||||
"""
|
||||
get the trade unit of amount based on **factor**
|
||||
the factor can be given directly or calculated in given time range and stock id.
|
||||
@@ -617,14 +707,22 @@ class Exchange:
|
||||
"""
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
|
||||
def round_amount_by_trade_unit(
|
||||
self, deal_amount, factor: float = None, stock_id: str = None, start_time=None, end_time=None
|
||||
self,
|
||||
deal_amount,
|
||||
factor: float = None,
|
||||
stock_id: str = None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
):
|
||||
"""Parameter
|
||||
Please refer to the docs of get_amount_of_trade_unit
|
||||
@@ -635,7 +733,10 @@ class Exchange:
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
factor = self._get_factor_or_raise_error(
|
||||
factor=factor, stock_id=stock_id, start_time=start_time, end_time=end_time
|
||||
factor=factor,
|
||||
stock_id=stock_id,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
@@ -714,7 +815,12 @@ class Exchange:
|
||||
max_trade_amount = (cash - self.min_cost) / trade_price
|
||||
return max_trade_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position: Position, dealt_order_amount):
|
||||
def _calc_trade_info_by_order(
|
||||
self,
|
||||
order: Order,
|
||||
position: Optional[BasePosition],
|
||||
dealt_order_amount: dict,
|
||||
) -> Tuple[float, float, float]:
|
||||
"""
|
||||
Calculation of trade info
|
||||
**NOTE**: Order will be changed in this function
|
||||
@@ -753,7 +859,8 @@ class Exchange:
|
||||
if not np.isclose(order.deal_amount, current_amount):
|
||||
# when not selling last stock. rounding is necessary
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(current_amount, order.deal_amount), order.factor
|
||||
min(current_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
|
||||
# in case of negative value of cash
|
||||
@@ -778,7 +885,8 @@ class Exchange:
|
||||
# The money is not enough
|
||||
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
|
||||
order.deal_amount = self.round_amount_by_trade_unit(
|
||||
min(max_buy_amount, order.deal_amount), order.factor
|
||||
min(max_buy_amount, order.deal_amount),
|
||||
order.factor,
|
||||
)
|
||||
self.logger.debug(f"Order clipped due to cash limitation: {order}")
|
||||
else:
|
||||
|
||||
@@ -1,19 +1,28 @@
|
||||
from abc import abstractmethod
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from abc import abstractmethod
|
||||
from collections import defaultdict
|
||||
from types import GeneratorType
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.account import Account
|
||||
from qlib.backtest.position import BasePosition
|
||||
from qlib.log import get_module_logger
|
||||
from types import GeneratorType
|
||||
from qlib.backtest.account import Account
|
||||
import pandas as pd
|
||||
from typing import List, Tuple, Union
|
||||
from collections import defaultdict
|
||||
|
||||
from .decision import Order, BaseTradeDecision
|
||||
from .exchange import Exchange
|
||||
from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, get_start_end_idx
|
||||
|
||||
from ..utils import init_instance_by_config
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..utils import init_instance_by_config
|
||||
from .decision import BaseTradeDecision, Order
|
||||
from .exchange import Exchange
|
||||
from .utils import (
|
||||
BaseInfrastructure,
|
||||
CommonInfrastructure,
|
||||
LevelInfrastructure,
|
||||
TradeCalendarManager,
|
||||
get_start_end_idx,
|
||||
)
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
@@ -30,9 +39,9 @@ class BaseExecutor:
|
||||
track_data: bool = False,
|
||||
trade_exchange: Exchange = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
settle_type=BasePosition.ST_NO,
|
||||
settle_type=BasePosition.ST_NO, # TODO: add typehint
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -53,15 +62,21 @@ class BaseExecutor:
|
||||
- 'base_price': the based price than which the trading price is advanced, Optional, default by 'twap'
|
||||
- If 'base_price' is 'twap', the based price is the time weighted average price
|
||||
- If 'base_price' is 'vwap', the based price is the volume weighted average price
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading pa by different orders' pa in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' pa
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' pa
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' pa
|
||||
- 'ffr_config': config for calculating fulfill rate(ffr), optional
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each step, optional, default by 'mean'
|
||||
- 'weight_method': weighted method when calculating total trading ffr by different orders' ffr in each
|
||||
step, optional, default by 'mean'
|
||||
- If 'weight_method' is 'mean', calculating mean value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different orders' ffr
|
||||
- If 'weight_method' is 'amount_weighted', calculating amount weighted average value of different
|
||||
orders' ffr
|
||||
- If 'weight_method' is 'value_weighted', calculating value weighted average value of different
|
||||
orders' ffr
|
||||
Example:
|
||||
{
|
||||
'show_indicator': True,
|
||||
@@ -79,7 +94,8 @@ class BaseExecutor:
|
||||
whether to print trading info, by default False
|
||||
track_data : bool, optional
|
||||
whether to generate trade_decision, will be used when training rl agent
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will be generated by `collect_data`
|
||||
- If `self.track_data` is true, when making data for training, the input `trade_decision` of `execute` will
|
||||
be generated by `collect_data`
|
||||
- Else, `trade_decision` will not be generated
|
||||
|
||||
trade_exchange : Exchange
|
||||
@@ -114,7 +130,7 @@ class BaseExecutor:
|
||||
self.dealt_order_amount = defaultdict(float)
|
||||
self.deal_day = None
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: BaseInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset trade_account
|
||||
@@ -132,7 +148,7 @@ class BaseExecutor:
|
||||
# 2. Others are not shared, so each level has it own metrics (portfolio and trading metrics)
|
||||
self.trade_account: Account = copy.copy(common_infra.get("trade_account"))
|
||||
else:
|
||||
self.trade_account = common_infra.get("trade_account")
|
||||
self.trade_account: Account = common_infra.get("trade_account")
|
||||
self.trade_account.reset(freq=self.time_per_step, port_metr_enabled=self.generate_portfolio_metrics)
|
||||
|
||||
@property
|
||||
@@ -148,7 +164,7 @@ class BaseExecutor:
|
||||
"""
|
||||
return self.level_infra.get("trade_calendar")
|
||||
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs):
|
||||
def reset(self, common_infra: CommonInfrastructure = None, **kwargs) -> None:
|
||||
"""
|
||||
- reset `start_time` and `end_time`, used in trade calendar
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -161,13 +177,13 @@ class BaseExecutor:
|
||||
if common_infra is not None:
|
||||
self.reset_common_infra(common_infra)
|
||||
|
||||
def get_level_infra(self):
|
||||
def get_level_infra(self) -> LevelInfrastructure:
|
||||
return self.level_infra
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
return self.trade_calendar.finished()
|
||||
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def execute(self, trade_decision: BaseTradeDecision, level: int = 0) -> List[object]:
|
||||
"""execute the trade decision and return the executed result
|
||||
|
||||
NOTE: this function is never used directly in the framework. Should we delete it?
|
||||
@@ -189,9 +205,15 @@ class BaseExecutor:
|
||||
pass
|
||||
return return_value.get("execute_result")
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def _collect_data(cls, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Union[
|
||||
Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]],
|
||||
Tuple[List[object], dict],
|
||||
]:
|
||||
"""
|
||||
Please refer to the doc of collect_data
|
||||
The only difference between `_collect_data` and `collect_data` is that some common steps are moved into
|
||||
@@ -209,8 +231,11 @@ class BaseExecutor:
|
||||
"""
|
||||
|
||||
def collect_data(
|
||||
self, trade_decision: BaseTradeDecision, return_value: dict = None, level: int = 0
|
||||
) -> List[object]:
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
return_value: dict = None,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], List[object]]:
|
||||
"""Generator for collecting the trade decision data for rl training
|
||||
|
||||
his function will make a step forward
|
||||
@@ -253,7 +278,9 @@ class BaseExecutor:
|
||||
obj = self._collect_data(trade_decision=trade_decision, level=level)
|
||||
|
||||
if isinstance(obj, GeneratorType):
|
||||
res, kwargs = yield from obj
|
||||
yield_res = yield from obj
|
||||
assert isinstance(yield_res, tuple) and len(yield_res) == 2
|
||||
res, kwargs = yield_res
|
||||
else:
|
||||
# Some concrete executor don't have inner decisions
|
||||
res, kwargs = obj
|
||||
@@ -279,7 +306,7 @@ class BaseExecutor:
|
||||
return_value.update({"execute_result": res})
|
||||
return res
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[BaseExecutor]:
|
||||
"""get all executors"""
|
||||
return [self]
|
||||
|
||||
@@ -287,7 +314,8 @@ class BaseExecutor:
|
||||
class NestedExecutor(BaseExecutor):
|
||||
"""
|
||||
Nested Executor with inner strategy and executor
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env.
|
||||
- At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision`
|
||||
in a higher frequency env.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -305,7 +333,7 @@ class NestedExecutor(BaseExecutor):
|
||||
align_range_limit: bool = True,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -323,10 +351,14 @@ class NestedExecutor(BaseExecutor):
|
||||
It is only for nested executor, because range_limit is given by outer strategy
|
||||
"""
|
||||
self.inner_executor: BaseExecutor = init_instance_by_config(
|
||||
inner_executor, common_infra=common_infra, accept_types=BaseExecutor
|
||||
inner_executor,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseExecutor,
|
||||
)
|
||||
self.inner_strategy: BaseStrategy = init_instance_by_config(
|
||||
inner_strategy, common_infra=common_infra, accept_types=BaseStrategy
|
||||
inner_strategy,
|
||||
common_infra=common_infra,
|
||||
accept_types=BaseStrategy,
|
||||
)
|
||||
|
||||
self._skip_empty_decision = skip_empty_decision
|
||||
@@ -344,10 +376,10 @@ class NestedExecutor(BaseExecutor):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def reset_common_infra(self, common_infra, copy_trade_account=False):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure, copy_trade_account: bool = False) -> None:
|
||||
"""
|
||||
reset infrastructure for trading
|
||||
- reset inner_strategyand inner_executor common infra
|
||||
- reset inner_strategy and inner_executor common infra
|
||||
"""
|
||||
# NOTE: please refer to the docs of BaseExecutor.reset_common_infra for the meaning of `copy_trade_account`
|
||||
|
||||
@@ -358,7 +390,7 @@ class NestedExecutor(BaseExecutor):
|
||||
self.inner_executor.reset_common_infra(common_infra, copy_trade_account=True)
|
||||
self.inner_strategy.reset_common_infra(common_infra)
|
||||
|
||||
def _init_sub_trading(self, trade_decision):
|
||||
def _init_sub_trading(self, trade_decision: BaseTradeDecision) -> None:
|
||||
trade_start_time, trade_end_time = self.trade_calendar.get_step_time()
|
||||
self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
sub_level_infra = self.inner_executor.get_level_infra()
|
||||
@@ -368,14 +400,18 @@ class NestedExecutor(BaseExecutor):
|
||||
def _update_trade_decision(self, trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
# outer strategy have chance to update decision each iterator
|
||||
updated_trade_decision = trade_decision.update(self.inner_executor.trade_calendar)
|
||||
if updated_trade_decision is not None:
|
||||
if updated_trade_decision is not None: # TODO: always is None for now?
|
||||
trade_decision = updated_trade_decision
|
||||
# NEW UPDATE
|
||||
# create a hook for inner strategy to update outer decision
|
||||
self.inner_strategy.alter_outer_trade_decision(trade_decision)
|
||||
return trade_decision
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
def _collect_data(
|
||||
self,
|
||||
trade_decision: BaseTradeDecision,
|
||||
level: int = 0,
|
||||
) -> Generator[BaseTradeDecision, Optional[BaseTradeDecision], Tuple[List[object], dict]]:
|
||||
execute_result = []
|
||||
inner_order_indicators = []
|
||||
decision_list = []
|
||||
@@ -390,8 +426,8 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
if trade_decision.empty() and self._skip_empty_decision:
|
||||
# give one chance for outer strategy to update the strategy
|
||||
# - For updating some information in the sub executor(the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
# - For updating some information in the sub executor (the strategy have no knowledge of the inner
|
||||
# executor when generating the decision)
|
||||
break
|
||||
|
||||
sub_cal: TradeCalendarManager = self.inner_executor.trade_calendar
|
||||
@@ -405,15 +441,19 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: !!!!!
|
||||
# the two lines below is for a special case in RL
|
||||
# To solve the confliction below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction loop
|
||||
# For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=> (inner Qlib Executor)])
|
||||
# To solve the conflicts below
|
||||
# - Normally, user will create a strategy and embed it into Qlib's executor and simulator interaction
|
||||
# loop For a _nested qlib example_, (Qlib Strategy) <=> (Qlib Executor[(inner Qlib Strategy) <=>
|
||||
# (inner Qlib Executor)])
|
||||
# - However, RL-based framework has it's own script to run the loop
|
||||
# For an _RL learning example_, (RL Policy) <=> (RL Env[(inner Qlib Executor)])
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of RL Framework
|
||||
# To make it possible to run _nested qlib example_ and _RL learning example_ together, the solution
|
||||
# below is proposed
|
||||
# - The entry script follow the example of _RL learning example_ to be compatible with all kinds of
|
||||
# RL Framework
|
||||
# - Each step of (RL Env) will make (inner Qlib Executor) one step forward
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env) by `yield from` and wait for the action from the policy
|
||||
# - (inner Qlib Strategy) is a proxy strategy, it will give the program control right to (RL Env)
|
||||
# by `yield from` and wait for the action from the policy
|
||||
# So the two lines below is the implementation of yielding control rights
|
||||
if isinstance(res, GeneratorType):
|
||||
res = yield from res
|
||||
@@ -427,13 +467,15 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
# NOTE: Trade Calendar will step forward in the follow line
|
||||
_inner_execute_result = yield from self.inner_executor.collect_data(
|
||||
trade_decision=_inner_trade_decision, level=level + 1
|
||||
trade_decision=_inner_trade_decision,
|
||||
level=level + 1,
|
||||
)
|
||||
assert isinstance(_inner_execute_result, list)
|
||||
self.post_inner_exe_step(_inner_execute_result)
|
||||
execute_result.extend(_inner_execute_result)
|
||||
|
||||
inner_order_indicators.append(
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True)
|
||||
self.inner_executor.trade_account.get_trade_indicator().get_order_indicator(raw=True),
|
||||
)
|
||||
else:
|
||||
# do nothing and just step forward
|
||||
@@ -441,7 +483,7 @@ class NestedExecutor(BaseExecutor):
|
||||
|
||||
return execute_result, {"inner_order_indicators": inner_order_indicators, "decision_list": decision_list}
|
||||
|
||||
def post_inner_exe_step(self, inner_exe_res):
|
||||
def post_inner_exe_step(self, inner_exe_res: List[object]) -> None:
|
||||
"""
|
||||
A hook for doing sth after each step of inner strategy
|
||||
|
||||
@@ -451,11 +493,23 @@ class NestedExecutor(BaseExecutor):
|
||||
the execution result of inner task
|
||||
"""
|
||||
|
||||
def get_all_executors(self):
|
||||
def get_all_executors(self) -> List[object]:
|
||||
"""get all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
def _retrieve_orders_from_decision(trade_decision: BaseTradeDecision) -> List[Order]:
|
||||
"""
|
||||
IDE-friendly helper function.
|
||||
"""
|
||||
decisions = trade_decision.get_decision()
|
||||
orders: List[Order] = []
|
||||
for decision in decisions:
|
||||
assert isinstance(decision, Order)
|
||||
orders.append(decision)
|
||||
return orders
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
"""Executor that simulate the true market"""
|
||||
|
||||
@@ -464,10 +518,10 @@ class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
# available trade_types
|
||||
TT_SERIAL = "serial"
|
||||
## The orders will be executed serially in a sequence
|
||||
# The orders will be executed serially in a sequence
|
||||
# In each trading step, it is possible that users sell instruments first and use the money to buy new instruments
|
||||
TT_PARAL = "parallel"
|
||||
## The orders will be executed parallelly
|
||||
# The orders will be executed in parallel
|
||||
# In each trading step, if users try to sell instruments first and buy new instruments with money, failure will
|
||||
# occur
|
||||
|
||||
@@ -483,7 +537,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_type: str = TT_SERIAL,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -517,7 +571,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
List[Order]:
|
||||
get a list orders according to `self.trade_type`
|
||||
"""
|
||||
orders = trade_decision.get_decision()
|
||||
orders = _retrieve_orders_from_decision(trade_decision)
|
||||
|
||||
if self.trade_type == self.TT_SERIAL:
|
||||
# Orders will be traded in a parallel way
|
||||
@@ -525,15 +579,15 @@ class SimulatorExecutor(BaseExecutor):
|
||||
elif self.trade_type == self.TT_PARAL:
|
||||
# NOTE: !!!!!!!
|
||||
# Assumption: there will not be orders in different trading direction in a single step of a strategy !!!!
|
||||
# The parallel trading failure will be caused only by the confliction of money
|
||||
# Therefore, make the buying go first will make sure the confliction happen.
|
||||
# The parallel trading failure will be caused only by the conflicts of money
|
||||
# Therefore, make the buying go first will make sure the conflicts happen.
|
||||
# It equals to parallel trading after sorting the order by direction
|
||||
order_it = sorted(orders, key=lambda order: -order.direction)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
return order_it
|
||||
|
||||
def _update_dealt_order_amount(self, order):
|
||||
def _update_dealt_order_amount(self, order: Order) -> None:
|
||||
"""update date and dealt order amount in the day."""
|
||||
|
||||
now_deal_day = self.trade_calendar.get_step_time()[0].floor(freq="D")
|
||||
@@ -542,8 +596,7 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self.deal_day = now_deal_day
|
||||
self.dealt_order_amount[order.stock_id] += order.deal_amount
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0):
|
||||
|
||||
def _collect_data(self, trade_decision: BaseTradeDecision, level: int = 0) -> Tuple[List[object], dict]:
|
||||
trade_start_time, _ = self.trade_calendar.get_step_time()
|
||||
execute_result = []
|
||||
|
||||
@@ -559,7 +612,8 @@ class SimulatorExecutor(BaseExecutor):
|
||||
self._update_dealt_order_amount(order)
|
||||
if self.verbose:
|
||||
print(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}, cash {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d %H:%M:%S}]: {} {}, price {:.2f}, amount {}, deal_amount {}, factor {}, "
|
||||
"value {:.2f}, cash {:.2f}.".format(
|
||||
trade_start_time,
|
||||
"sell" if order.direction == Order.SELL else "buy",
|
||||
order.stock_id,
|
||||
@@ -569,6 +623,6 @@ class SimulatorExecutor(BaseExecutor):
|
||||
order.factor,
|
||||
trade_val,
|
||||
self.trade_account.get_cash(),
|
||||
)
|
||||
),
|
||||
)
|
||||
return execute_result, {"trade_info": execute_result}
|
||||
|
||||
@@ -1,20 +1,21 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from functools import lru_cache
|
||||
import logging
|
||||
from typing import List, Text, Union, Callable, Iterable, Dict
|
||||
from collections import OrderedDict
|
||||
|
||||
import inspect
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache
|
||||
from typing import Callable, Dict, Iterable, List, Text, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
import qlib.utils.index_data as idd
|
||||
|
||||
from ..log import get_module_logger
|
||||
from ..utils.index_data import IndexData, SingleData
|
||||
from ..utils.resam import resam_ts_data, ts_data_last
|
||||
from ..log import get_module_logger
|
||||
from ..utils.time import is_single_value, Freq
|
||||
import qlib.utils.index_data as idd
|
||||
from ..utils.time import Freq, is_single_value
|
||||
|
||||
|
||||
class BaseQuote:
|
||||
@@ -627,7 +628,9 @@ class NumpyOrderIndicator(BaseOrderIndicator):
|
||||
metrics = [metrics]
|
||||
for metric in metrics:
|
||||
order_indicator.data[metric] = idd.sum_by_index(
|
||||
[indicator.data[metric] for indicator in indicators], stocks, fill_value
|
||||
[indicator.data[metric] for indicator in indicators],
|
||||
stocks,
|
||||
fill_value,
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
|
||||
@@ -2,24 +2,28 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Union
|
||||
|
||||
import pandas as pd
|
||||
from datetime import timedelta
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .decision import Order
|
||||
from ..data.data import D
|
||||
from .decision import Order
|
||||
|
||||
|
||||
class BasePosition:
|
||||
"""
|
||||
The Position want to maintain the position like a dictionary
|
||||
The Position wants to maintain the position like a dictionary
|
||||
Please refer to the `Position` class for the position
|
||||
"""
|
||||
|
||||
def __init__(self, *args, cash=0.0, **kwargs):
|
||||
def __init__(self, *args, cash: float = 0.0, **kwargs) -> None:
|
||||
self._settle_type = self.ST_NO
|
||||
self.position = {}
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
pass
|
||||
|
||||
def skip_update(self) -> bool:
|
||||
"""
|
||||
@@ -49,7 +53,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `check_stock` method")
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -64,7 +68,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `update_order` method")
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
"""
|
||||
Updating the latest price of the order
|
||||
The useful when clearing balance at each bar end
|
||||
@@ -89,6 +93,9 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `calculate_stock_value` method")
|
||||
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"Please implement the `calculate_value` method")
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
"""
|
||||
Get the list of stocks in the position.
|
||||
@@ -124,14 +131,16 @@ class BasePosition:
|
||||
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
|
||||
Returns
|
||||
-------
|
||||
float:
|
||||
the available(tradable) cash in position
|
||||
include_settle:
|
||||
will the unsettled(delayed) cash included
|
||||
Default: not include those unavailable cash
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_cash` method")
|
||||
|
||||
@@ -165,7 +174,7 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `get_stock_weight_dict` method")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar) -> None:
|
||||
"""
|
||||
Will be called at the end of each bar on each level
|
||||
|
||||
@@ -176,24 +185,19 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
"""
|
||||
Updating the position weight;
|
||||
|
||||
# TODO: this function is a little weird. The weight data in the position is in a wrong state after dealing order
|
||||
# and before updating weight.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bar :
|
||||
The level to be updated
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `add_count_all` method")
|
||||
|
||||
ST_CASH = "cash"
|
||||
ST_NO = None
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
"""
|
||||
settlement start
|
||||
It will act like start and commit a transaction
|
||||
@@ -210,14 +214,9 @@ class BasePosition:
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_conf` method")
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
"""
|
||||
settlement commit
|
||||
|
||||
Parameters
|
||||
----------
|
||||
settle_type : str
|
||||
please refer to the documents of Executor
|
||||
"""
|
||||
raise NotImplementedError(f"Please implement the `settle_commit` method")
|
||||
|
||||
@@ -242,13 +241,11 @@ class Position(BasePosition):
|
||||
}
|
||||
"""
|
||||
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Dict[str, float]] = {}):
|
||||
def __init__(self, cash: float = 0, position_dict: Dict[str, Union[Dict[str, float], float]] = {}) -> None:
|
||||
"""Init position by cash and position_dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest. It's for filling the initial value of stocks.
|
||||
cash : float, optional
|
||||
initial cash in account, by default 0
|
||||
position_dict : Dict[
|
||||
@@ -268,9 +265,9 @@ class Position(BasePosition):
|
||||
# Otherwise the initial value
|
||||
self.init_cash = cash
|
||||
self.position = position_dict.copy()
|
||||
for stock in self.position:
|
||||
if isinstance(self.position[stock], int):
|
||||
self.position[stock] = {"amount": self.position[stock]}
|
||||
for stock, value in self.position.items():
|
||||
if isinstance(value, int):
|
||||
self.position[stock] = {"amount": value}
|
||||
self.position["cash"] = cash
|
||||
|
||||
# If the stock price information is missing, the account value will not be calculated temporarily
|
||||
@@ -279,21 +276,23 @@ class Position(BasePosition):
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30):
|
||||
def fill_stock_value(self, start_time: Union[str, pd.Timestamp], freq: str, last_days: int = 30) -> None:
|
||||
"""fill the stock value by the close price of latest last_days from qlib.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time :
|
||||
the start time of backtest.
|
||||
freq : str
|
||||
Frequency
|
||||
last_days : int, optional
|
||||
the days to get the latest close price, by default 30.
|
||||
"""
|
||||
stock_list = []
|
||||
for stock in self.position:
|
||||
if not isinstance(self.position[stock], dict):
|
||||
for stock, value in self.position.items():
|
||||
if not isinstance(value, dict):
|
||||
continue
|
||||
if ("price" not in self.position[stock]) or (self.position[stock]["price"] is None):
|
||||
if value.get("price", None) is None:
|
||||
stock_list.append(stock)
|
||||
|
||||
if len(stock_list) == 0:
|
||||
@@ -304,7 +303,12 @@ class Position(BasePosition):
|
||||
price_end_time = start_time
|
||||
price_start_time = start_time - timedelta(days=last_days)
|
||||
price_df = D.features(
|
||||
stock_list, ["$close"], price_start_time, price_end_time, freq=freq, disk_cache=True
|
||||
stock_list,
|
||||
["$close"],
|
||||
price_start_time,
|
||||
price_end_time,
|
||||
freq=freq,
|
||||
disk_cache=True,
|
||||
).dropna()
|
||||
price_dict = price_df.groupby(["instrument"]).tail(1).reset_index(level=1, drop=True)["$close"].to_dict()
|
||||
|
||||
@@ -316,7 +320,7 @@ class Position(BasePosition):
|
||||
self.position[stock]["price"] = price_dict[stock]
|
||||
self.position["now_account_value"] = self.calculate_value()
|
||||
|
||||
def _init_stock(self, stock_id, amount, price=None):
|
||||
def _init_stock(self, stock_id: str, amount: float, price: float = None) -> None:
|
||||
"""
|
||||
initialization the stock in current position
|
||||
|
||||
@@ -334,7 +338,7 @@ class Position(BasePosition):
|
||||
self.position[stock_id]["price"] = price
|
||||
self.position[stock_id]["weight"] = 0 # update the weight in the end of the trade date
|
||||
|
||||
def _buy_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _buy_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
self._init_stock(stock_id=stock_id, amount=trade_amount, price=trade_price)
|
||||
@@ -344,15 +348,16 @@ class Position(BasePosition):
|
||||
|
||||
self.position["cash"] -= trade_val + cost
|
||||
|
||||
def _sell_stock(self, stock_id, trade_val, cost, trade_price):
|
||||
def _sell_stock(self, stock_id: str, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
trade_amount = trade_val / trade_price
|
||||
if stock_id not in self.position:
|
||||
raise KeyError("{} not in current position".format(stock_id))
|
||||
else:
|
||||
if np.isclose(self.position[stock_id]["amount"], trade_amount):
|
||||
# Selling all the stocks
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both ralative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
# we use np.isclose instead of abs(<the final amount>) <= 1e-5 because `np.isclose` consider both
|
||||
# relative amount and absolute amount
|
||||
# Using abs(<the final amount>) <= 1e-5 will result in error when the amount is large
|
||||
self._del_stock(stock_id)
|
||||
else:
|
||||
# decrease the amount of stock
|
||||
@@ -361,8 +366,10 @@ class Position(BasePosition):
|
||||
if self.position[stock_id]["amount"] < -1e-5:
|
||||
raise ValueError(
|
||||
"only have {} {}, require {}".format(
|
||||
self.position[stock_id]["amount"] + trade_amount, stock_id, trade_amount
|
||||
)
|
||||
self.position[stock_id]["amount"] + trade_amount,
|
||||
stock_id,
|
||||
trade_amount,
|
||||
),
|
||||
)
|
||||
|
||||
new_cash = trade_val - cost
|
||||
@@ -373,13 +380,13 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
|
||||
def _del_stock(self, stock_id):
|
||||
def _del_stock(self, stock_id: str) -> None:
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
def check_stock(self, stock_id: str) -> bool:
|
||||
return stock_id in self.position
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
# BUY
|
||||
@@ -390,54 +397,54 @@ class Position(BasePosition):
|
||||
else:
|
||||
raise NotImplementedError("do not support order direction {}".format(order.direction))
|
||||
|
||||
def update_stock_price(self, stock_id, price):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
self.position[stock_id]["price"] = price
|
||||
|
||||
def update_stock_count(self, stock_id, bar, count):
|
||||
def update_stock_count(self, stock_id: str, bar: str, count: float) -> None: # TODO: check type of `bar`
|
||||
self.position[stock_id][f"count_{bar}"] = count
|
||||
|
||||
def update_stock_weight(self, stock_id, weight):
|
||||
def update_stock_weight(self, stock_id: str, weight: float) -> None:
|
||||
self.position[stock_id]["weight"] = weight
|
||||
|
||||
def calculate_stock_value(self):
|
||||
def calculate_stock_value(self) -> float:
|
||||
stock_list = self.get_stock_list()
|
||||
value = 0
|
||||
for stock_id in stock_list:
|
||||
value += self.position[stock_id]["amount"] * self.position[stock_id]["price"]
|
||||
return value
|
||||
|
||||
def calculate_value(self):
|
||||
def calculate_value(self) -> float:
|
||||
value = self.calculate_stock_value()
|
||||
value += self.position["cash"] + self.position.get("cash_delay", 0.0)
|
||||
return value
|
||||
|
||||
def get_stock_list(self):
|
||||
def get_stock_list(self) -> List[str]:
|
||||
stock_list = list(set(self.position.keys()) - {"cash", "now_account_value", "cash_delay"})
|
||||
return stock_list
|
||||
|
||||
def get_stock_price(self, code):
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
return self.position[code]["price"]
|
||||
|
||||
def get_stock_amount(self, code):
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return self.position[code]["amount"] if code in self.position else 0
|
||||
|
||||
def get_stock_count(self, code, bar):
|
||||
def get_stock_count(self, code: str, bar: str) -> float:
|
||||
"""the days the account has been hold, it may be used in some special strategies"""
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
return self.position[code][f"count_{bar}"]
|
||||
else:
|
||||
return 0
|
||||
|
||||
def get_stock_weight(self, code):
|
||||
def get_stock_weight(self, code: str) -> float:
|
||||
return self.position[code]["weight"]
|
||||
|
||||
def get_cash(self, include_settle=False):
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
cash = self.position["cash"]
|
||||
if include_settle:
|
||||
cash += self.position.get("cash_delay", 0.0)
|
||||
return cash
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
def get_stock_amount_dict(self) -> dict:
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
@@ -445,7 +452,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.get_stock_amount(code=stock_code)
|
||||
return d
|
||||
|
||||
def get_stock_weight_dict(self, only_stock=False):
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> dict:
|
||||
"""get_stock_weight_dict
|
||||
generate stock weight dict {stock_id : value weight of stock in the position}
|
||||
it is meaningful in the beginning or the end of each trade date
|
||||
@@ -463,7 +470,7 @@ class Position(BasePosition):
|
||||
d[stock_code] = self.position[stock_code]["amount"] * self.position[stock_code]["price"] / position_value
|
||||
return d
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
stock_list = self.get_stock_list()
|
||||
for code in stock_list:
|
||||
if f"count_{bar}" in self.position[code]:
|
||||
@@ -471,18 +478,18 @@ class Position(BasePosition):
|
||||
else:
|
||||
self.position[code][f"count_{bar}"] = 1
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
weight_dict = self.get_stock_weight_dict()
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def settle_start(self, settle_type):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
assert self._settle_type == self.ST_NO, "Currently, settlement can't be nested!!!!!"
|
||||
self._settle_type = settle_type
|
||||
if settle_type == self.ST_CASH:
|
||||
self.position["cash_delay"] = 0.0
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
if self._settle_type != self.ST_NO:
|
||||
if self._settle_type == self.ST_CASH:
|
||||
self.position["cash"] += self.position["cash_delay"]
|
||||
@@ -507,10 +514,10 @@ class InfPosition(BasePosition):
|
||||
# InfPosition always have any stocks
|
||||
return True
|
||||
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float):
|
||||
def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float) -> None:
|
||||
pass
|
||||
|
||||
def update_stock_price(self, stock_id, price: float):
|
||||
def update_stock_price(self, stock_id: str, price: float) -> None:
|
||||
pass
|
||||
|
||||
def calculate_stock_value(self) -> float:
|
||||
@@ -522,17 +529,20 @@ class InfPosition(BasePosition):
|
||||
"""
|
||||
return np.inf
|
||||
|
||||
def get_stock_list(self) -> List:
|
||||
def calculate_value(self) -> float:
|
||||
raise NotImplementedError(f"InfPosition doesn't support calculating value")
|
||||
|
||||
def get_stock_list(self) -> list:
|
||||
raise NotImplementedError(f"InfPosition doesn't support stock list position")
|
||||
|
||||
def get_stock_price(self, code) -> float:
|
||||
def get_stock_price(self, code: str) -> float:
|
||||
"""the price of the inf position is meaningless"""
|
||||
return np.nan
|
||||
|
||||
def get_stock_amount(self, code) -> float:
|
||||
def get_stock_amount(self, code: str) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_cash(self, include_settle=False) -> float:
|
||||
def get_cash(self, include_settle: bool = False) -> float:
|
||||
return np.inf
|
||||
|
||||
def get_stock_amount_dict(self) -> Dict:
|
||||
@@ -541,14 +551,14 @@ class InfPosition(BasePosition):
|
||||
def get_stock_weight_dict(self, only_stock: bool = False) -> Dict:
|
||||
raise NotImplementedError(f"InfPosition doesn't support get_stock_weight_dict")
|
||||
|
||||
def add_count_all(self, bar):
|
||||
def add_count_all(self, bar: str) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support add_count_all")
|
||||
|
||||
def update_weight_all(self):
|
||||
def update_weight_all(self) -> None:
|
||||
raise NotImplementedError(f"InfPosition doesn't support update_weight_all")
|
||||
|
||||
def settle_start(self, settle_type: str):
|
||||
def settle_start(self, settle_type: str) -> None:
|
||||
pass
|
||||
|
||||
def settle_commit(self):
|
||||
def settle_commit(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -4,14 +4,16 @@
|
||||
This module is not well maintained.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from .position import Position
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ..config import C
|
||||
from ..data import D
|
||||
from .position import Position
|
||||
|
||||
|
||||
def get_benchmark_weight(
|
||||
bench,
|
||||
@@ -214,7 +216,9 @@ def get_stock_group(stock_group_field_df, bench_stock_weight_df, group_method, g
|
||||
for idx, row in (~bench_stock_weight_df.isna()).iterrows():
|
||||
bench_values = stock_group_field_df.loc[idx, row[row].index]
|
||||
new_stock_group_df.loc[idx] = get_daily_bin_group(
|
||||
bench_values, stock_group_field_df.loc[idx], group_n=group_n
|
||||
bench_values,
|
||||
stock_group_field_df.loc[idx],
|
||||
group_n=group_n,
|
||||
)
|
||||
return new_stock_group_df
|
||||
|
||||
@@ -315,7 +319,7 @@ def brinson_pa(
|
||||
# The excess profit from the interaction of assets allocation and stocks selection
|
||||
"RIN": Q4 - Q3 - Q2 + Q1,
|
||||
"RTotal": Q4 - Q1, # The totoal excess profit
|
||||
}
|
||||
},
|
||||
),
|
||||
{
|
||||
"port_group_ret": port_group_ret_df,
|
||||
|
||||
@@ -2,19 +2,20 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
import pathlib
|
||||
from collections import OrderedDict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from qlib.backtest.exchange import Exchange
|
||||
import qlib.utils.index_data as idd
|
||||
from qlib.backtest.decision import BaseTradeDecision, Order, OrderDir
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
from qlib.backtest.exchange import Exchange
|
||||
|
||||
from ..tests.config import CSI300_BENCH
|
||||
from ..utils.resam import get_higher_eq_freq_feature, resam_ts_data
|
||||
import qlib.utils.index_data as idd
|
||||
from .high_performance_ds import BaseOrderIndicator, NumpyOrderIndicator, SingleMetric
|
||||
|
||||
|
||||
class PortfolioMetrics:
|
||||
@@ -161,7 +162,8 @@ class PortfolioMetrics:
|
||||
stock_value,
|
||||
]:
|
||||
raise ValueError(
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, total_cost, cost_rate, stock_value]"
|
||||
"None in [trade_start_time, account_value, cash, return_rate, total_turnover, turnover_rate, "
|
||||
"total_cost, cost_rate, stock_value]",
|
||||
)
|
||||
|
||||
if trade_end_time is None and bench_value is None:
|
||||
@@ -335,7 +337,10 @@ class Indicator:
|
||||
# sum inner order indicators with same metric.
|
||||
all_metric = ["inner_amount", "deal_amount", "trade_price", "trade_value", "trade_cost", "trade_dir"]
|
||||
self.order_indicator_cls.sum_all_indicators(
|
||||
self.order_indicator, inner_order_indicators, all_metric, fill_value=0
|
||||
self.order_indicator,
|
||||
inner_order_indicators,
|
||||
all_metric,
|
||||
fill_value=0,
|
||||
)
|
||||
|
||||
def func(trade_price, deal_amount):
|
||||
@@ -378,12 +383,17 @@ class Indicator:
|
||||
|
||||
if decision.trade_range is not None:
|
||||
trade_start_time, trade_end_time = decision.trade_range.clip_time_range(
|
||||
start_time=trade_start_time, end_time=trade_end_time
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
)
|
||||
|
||||
if price == "deal_price":
|
||||
price_s = trade_exchange.get_deal_price(
|
||||
inst, trade_start_time, trade_end_time, direction=direction, method=None
|
||||
inst,
|
||||
trade_start_time,
|
||||
trade_end_time,
|
||||
direction=direction,
|
||||
method=None,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"This type of input is not supported")
|
||||
@@ -599,8 +609,12 @@ class Indicator:
|
||||
if show_indicator:
|
||||
print(
|
||||
"[Indicator({}) {:%Y-%m-%d %H:%M:%S}]: FFR: {}, PA: {}, POS: {}".format(
|
||||
freq, trade_start_time, fulfill_rate, price_advantage, positive_rate
|
||||
)
|
||||
freq,
|
||||
trade_start_time,
|
||||
fulfill_rate,
|
||||
price_advantage,
|
||||
positive_rate,
|
||||
),
|
||||
)
|
||||
|
||||
def get_order_indicator(self, raw: bool = True):
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from qlib.utils import init_instance_by_config
|
||||
import abc
|
||||
from typing import Dict, List, Text, Tuple, Union
|
||||
from ..model.base import BaseModel
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from qlib.utils import init_instance_by_config
|
||||
|
||||
from ..data.dataset import Dataset
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..model.base import BaseModel
|
||||
from ..utils.resam import resam_ts_data
|
||||
import pandas as pd
|
||||
import abc
|
||||
|
||||
|
||||
class Signal(metaclass=abc.ABCMeta):
|
||||
@@ -82,7 +85,7 @@ class ModelSignal(SignalWCache):
|
||||
|
||||
|
||||
def create_signal_from(
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
|
||||
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
|
||||
) -> Signal:
|
||||
"""
|
||||
create signal from diverse information
|
||||
|
||||
@@ -2,16 +2,22 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import bisect
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Set, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from qlib.utils.time import epsilon_change
|
||||
from typing import TYPE_CHECKING, Tuple, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.decision import BaseTradeDecision
|
||||
|
||||
import pandas as pd
|
||||
import warnings
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
|
||||
@@ -26,8 +32,8 @@ class TradeCalendarManager:
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
level_infra: "LevelInfrastructure" = None,
|
||||
):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -43,19 +49,26 @@ class TradeCalendarManager:
|
||||
self.level_infra = level_infra
|
||||
self.reset(freq=freq, start_time=start_time, end_time=end_time)
|
||||
|
||||
def reset(self, freq, start_time, end_time):
|
||||
def reset(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Please refer to the docs of `__init__`
|
||||
|
||||
Reset the trade calendar
|
||||
- self.trade_len : The total count for trading step
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1]
|
||||
- self.trade_step : The number of trading step finished, self.trade_step can be
|
||||
[0, 1, 2, ..., self.trade_len - 1]
|
||||
"""
|
||||
self.freq = freq
|
||||
self.start_time = pd.Timestamp(start_time) if start_time else None
|
||||
self.end_time = pd.Timestamp(end_time) if end_time else None
|
||||
|
||||
_calendar = Cal.calendar(freq=freq, future=True)
|
||||
assert isinstance(_calendar, np.ndarray)
|
||||
self._calendar = _calendar
|
||||
_, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, future=True)
|
||||
self.start_index = _start_index
|
||||
@@ -63,7 +76,7 @@ class TradeCalendarManager:
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
self.trade_step = 0
|
||||
|
||||
def finished(self):
|
||||
def finished(self) -> bool:
|
||||
"""
|
||||
Check if the trading finished
|
||||
- Should check before calling strategy.generate_decisions and executor.execute
|
||||
@@ -72,29 +85,32 @@ class TradeCalendarManager:
|
||||
"""
|
||||
return self.trade_step >= self.trade_len
|
||||
|
||||
def step(self):
|
||||
def step(self) -> None:
|
||||
if self.finished():
|
||||
raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!")
|
||||
self.trade_step = self.trade_step + 1
|
||||
self.trade_step += 1
|
||||
|
||||
def get_freq(self):
|
||||
def get_freq(self) -> str:
|
||||
return self.freq
|
||||
|
||||
def get_trade_len(self):
|
||||
def get_trade_len(self) -> int:
|
||||
"""get the total step length"""
|
||||
return self.trade_len
|
||||
|
||||
def get_trade_step(self):
|
||||
def get_trade_step(self) -> int:
|
||||
return self.trade_step
|
||||
|
||||
def get_step_time(self, trade_step=None, shift=0):
|
||||
def get_step_time(self, trade_step: int = None, shift: int = 0) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""
|
||||
Get the left and right endpoints of the trade_step'th trading interval
|
||||
|
||||
About the endpoints:
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time interval.
|
||||
- Qlib uses the closed interval in time-series data selection, which has the same performance as
|
||||
pandas.Series.loc
|
||||
# - The returned right endpoints should minus 1 seconds because of the closed interval representation in
|
||||
# Qlib.
|
||||
# Note: Qlib supports up to minutely decision execution, so 1 seconds is less than any trading time
|
||||
# interval.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -105,15 +121,14 @@ class TradeCalendarManager:
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[pd.Timestamp, pd.Timestap]
|
||||
Tuple[pd.Timestamp, pd.Timestamp]
|
||||
- If shift == 0, return the trading time range
|
||||
- If shift > 0, return the trading time range of the earlier shift bars
|
||||
- If shift < 0, return the trading time range of the later shift bar
|
||||
"""
|
||||
if trade_step is None:
|
||||
trade_step = self.get_trade_step()
|
||||
trade_step = trade_step - shift
|
||||
calendar_index = self.start_index + trade_step
|
||||
calendar_index = self.start_index + trade_step - shift
|
||||
return self._calendar[calendar_index], epsilon_change(self._calendar[calendar_index + 1])
|
||||
|
||||
def get_data_cal_range(self, rtype: str = "full") -> Tuple[int, int]:
|
||||
@@ -126,7 +141,7 @@ class TradeCalendarManager:
|
||||
Parameters
|
||||
----------
|
||||
rtype: str
|
||||
- "full": return the full limitation of the deicsion in the day
|
||||
- "full": return the full limitation of the decision in the day
|
||||
- "step": return the limitation of current step
|
||||
|
||||
Returns
|
||||
@@ -148,7 +163,7 @@ class TradeCalendarManager:
|
||||
|
||||
return start_idx - day_start_idx, end_index - day_start_idx
|
||||
|
||||
def get_all_time(self):
|
||||
def get_all_time(self) -> Tuple[pd.Timestamp, pd.Timestamp]:
|
||||
"""Get the start_time and end_time for trading"""
|
||||
return self.start_time, self.end_time
|
||||
|
||||
@@ -167,30 +182,33 @@ class TradeCalendarManager:
|
||||
Tuple[int, int]:
|
||||
the index of the range. **the left and right are closed**
|
||||
"""
|
||||
left, right = (
|
||||
bisect.bisect_right(self._calendar, start_time) - 1,
|
||||
bisect.bisect_right(self._calendar, end_time) - 1,
|
||||
)
|
||||
left = bisect.bisect_right(self._calendar, start_time) - 1
|
||||
right = bisect.bisect_right(self._calendar, end_time) - 1
|
||||
left -= self.start_index
|
||||
right -= self.start_index
|
||||
|
||||
def clip(idx):
|
||||
def clip(idx: int) -> int:
|
||||
return min(max(0, idx), self.trade_len - 1)
|
||||
|
||||
return clip(left), clip(right)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"class: {self.__class__.__name__}; {self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: [{self.trade_step}/{self.trade_len}]"
|
||||
return (
|
||||
f"class: {self.__class__.__name__}; "
|
||||
f"{self.start_time}[{self.start_index}]~{self.end_time}[{self.end_index}]: "
|
||||
f"[{self.trade_step}/{self.trade_len}]"
|
||||
)
|
||||
|
||||
|
||||
class BaseInfrastructure:
|
||||
def __init__(self, **kwargs):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
self.reset_infra(**kwargs)
|
||||
|
||||
def get_support_infra(self):
|
||||
@abstractmethod
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
raise NotImplementedError("`get_support_infra` is not implemented!")
|
||||
|
||||
def reset_infra(self, **kwargs):
|
||||
def reset_infra(self, **kwargs) -> None:
|
||||
support_infra = self.get_support_infra()
|
||||
for k, v in kwargs.items():
|
||||
if k in support_infra:
|
||||
@@ -198,53 +216,58 @@ class BaseInfrastructure:
|
||||
else:
|
||||
warnings.warn(f"{k} is ignored in `reset_infra`!")
|
||||
|
||||
def get(self, infra_name):
|
||||
def get(self, infra_name: str) -> Any:
|
||||
if hasattr(self, infra_name):
|
||||
return getattr(self, infra_name)
|
||||
else:
|
||||
warnings.warn(f"infra {infra_name} is not found!")
|
||||
|
||||
def has(self, infra_name):
|
||||
def has(self, infra_name: str) -> bool:
|
||||
return infra_name in self.get_support_infra() and hasattr(self, infra_name)
|
||||
|
||||
def update(self, other):
|
||||
def update(self, other: BaseInfrastructure) -> None:
|
||||
support_infra = other.get_support_infra()
|
||||
infra_dict = {_infra: getattr(other, _infra) for _infra in support_infra if hasattr(other, _infra)}
|
||||
self.reset_infra(**infra_dict)
|
||||
|
||||
|
||||
class CommonInfrastructure(BaseInfrastructure):
|
||||
def get_support_infra(self):
|
||||
return ["trade_account", "trade_exchange"]
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
return {"trade_account", "trade_exchange"}
|
||||
|
||||
|
||||
class LevelInfrastructure(BaseInfrastructure):
|
||||
"""level infrastructure is created by executor, and then shared to strategies on the same level"""
|
||||
|
||||
def get_support_infra(self):
|
||||
def get_support_infra(self) -> Set[str]:
|
||||
"""
|
||||
Descriptions about the infrastructure
|
||||
|
||||
sub_level_infra:
|
||||
- **NOTE**: this will only work after _init_sub_trading !!!
|
||||
"""
|
||||
return ["trade_calendar", "sub_level_infra", "common_infra"]
|
||||
return {"trade_calendar", "sub_level_infra", "common_infra"}
|
||||
|
||||
def reset_cal(self, freq, start_time, end_time):
|
||||
def reset_cal(
|
||||
self,
|
||||
freq: str,
|
||||
start_time: Union[str, pd.Timestamp, None],
|
||||
end_time: Union[str, pd.Timestamp, None],
|
||||
) -> None:
|
||||
"""reset trade calendar manager"""
|
||||
if self.has("trade_calendar"):
|
||||
self.get("trade_calendar").reset(freq, start_time=start_time, end_time=end_time)
|
||||
else:
|
||||
self.reset_infra(
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self)
|
||||
trade_calendar=TradeCalendarManager(freq, start_time=start_time, end_time=end_time, level_infra=self),
|
||||
)
|
||||
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure):
|
||||
"""this will make the calendar access easier when acrossing multi-levels"""
|
||||
def set_sub_level_infra(self, sub_level_infra: LevelInfrastructure) -> None:
|
||||
"""this will make the calendar access easier when crossing multi-levels"""
|
||||
self.reset_infra(sub_level_infra=sub_level_infra)
|
||||
|
||||
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Union[int, int]:
|
||||
def get_start_end_idx(trade_calendar: TradeCalendarManager, outer_trade_decision: BaseTradeDecision) -> Tuple[int, int]:
|
||||
"""
|
||||
A helper function for getting the decision-level index range limitation for inner strategy
|
||||
- NOTE: this function is not applicable to order-level
|
||||
|
||||
@@ -75,6 +75,17 @@ class Config:
|
||||
def set_conf_from_C(self, config_c):
|
||||
self.update(**config_c.__dict__["_config"])
|
||||
|
||||
def register_from_C(self, config, skip_register=True):
|
||||
from .utils import set_log_with_config # pylint: disable=C0415
|
||||
|
||||
if C.registered and skip_register:
|
||||
return
|
||||
|
||||
C.set_conf_from_C(config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
|
||||
|
||||
# pickle.dump protocol version: https://docs.python.org/3/library/pickle.html#data-stream-format
|
||||
PROTOCOL_VERSION = 4
|
||||
@@ -102,7 +113,7 @@ _default_config = {
|
||||
# "~/.qlib/stock_data/cn_data"
|
||||
# # dict
|
||||
# {"day": "~/.qlib/stock_data/cn_data", "1min": "~/.qlib/stock_data/cn_data_1min"}
|
||||
# NOTE: provider_uri priority:
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. backend_config: backend_obj["kwargs"]["provider_uri"]
|
||||
# 2. backend_config: backend_obj["kwargs"]["provider_uri_map"]
|
||||
# 3. qlib.init: provider_uri
|
||||
|
||||
@@ -63,11 +63,20 @@ def _get_date_parse_fn(target):
|
||||
get_date_parse_fn(20120101)('2017-01-01') => 20170101
|
||||
"""
|
||||
if isinstance(target, int):
|
||||
_fn = lambda x: int(str(x).replace("-", "")[:8]) # 20200201
|
||||
|
||||
def _fn(x):
|
||||
return int(str(x).replace("-", "")[:8]) # 20200201
|
||||
|
||||
elif isinstance(target, str) and len(target) == 8:
|
||||
_fn = lambda x: str(x).replace("-", "")[:8] # '20200201'
|
||||
|
||||
def _fn(x):
|
||||
return str(x).replace("-", "")[:8] # '20200201'
|
||||
|
||||
else:
|
||||
_fn = lambda x: x # '2021-01-01'
|
||||
|
||||
def _fn(x):
|
||||
return x # '2021-01-01'
|
||||
|
||||
return _fn
|
||||
|
||||
|
||||
|
||||
@@ -255,7 +255,10 @@ class Alpha158(DataHandlerLP):
|
||||
exclude = config["rolling"].get("exclude", [])
|
||||
# `exclude` in dataset config unnecessary filed
|
||||
# `include` in dataset config necessary field
|
||||
use = lambda x: x not in exclude and (include is None or x in include)
|
||||
|
||||
def use(x):
|
||||
return x not in exclude and (include is None or x in include)
|
||||
|
||||
if use("ROC"):
|
||||
fields += ["Ref($close, %d)/$close" % d for d in windows]
|
||||
names += ["ROC%d" % d for d in windows]
|
||||
|
||||
@@ -48,7 +48,9 @@ def calc_long_short_prec(
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
@@ -98,7 +100,10 @@ def calc_long_short_return(
|
||||
if dropna:
|
||||
df.dropna(inplace=True)
|
||||
group = df.groupby(level=date_col)
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
|
||||
def N(x):
|
||||
return int(len(x) * quantile)
|
||||
|
||||
r_long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label.mean())
|
||||
r_short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label.mean())
|
||||
r_avg = group.label.mean()
|
||||
|
||||
@@ -26,6 +26,13 @@ logger = get_module_logger("Evaluate")
|
||||
|
||||
def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
"""Risk Analysis
|
||||
NOTE:
|
||||
The calculation of annulaized return is different from the definition of annualized return.
|
||||
It is implemented by design.
|
||||
Qlib tries to cumulated returns by summation instead of production to avoid the cumulated curve being skewed exponentially.
|
||||
All the calculation of annualized returns follows this principle in Qlib.
|
||||
|
||||
TODO: add a parameter to enable calculating metrics with production accumulation of return.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
@@ -217,7 +217,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
----------
|
||||
task_tpl : Union[dict, list]
|
||||
Decide what tasks are used.
|
||||
- dict : the task template, the prepared task is generated with `step`, `trunc_days` and `RollingGen`
|
||||
- dict : the task template, the prepared task is generated with `step`, `trunc_days` and `RollingGen`
|
||||
- list : when list, use the list of tasks directly
|
||||
the list is supposed to be sorted according timeline
|
||||
step : int
|
||||
@@ -290,7 +290,7 @@ class MetaDatasetDS(MetaTaskDataset):
|
||||
ic_df = self.internal_data.data_ic_df
|
||||
|
||||
segs = task["dataset"]["kwargs"]["segments"]
|
||||
end = max([segs[k][1] for k in ("train", "valid") if k in segs])
|
||||
end = max(segs[k][1] for k in ("train", "valid") if k in segs)
|
||||
ic_df_avail = ic_df.loc[:end, pd.IndexSlice[:, :end]]
|
||||
|
||||
# meta data set focus on the **information** instead of preprocess
|
||||
|
||||
@@ -92,7 +92,10 @@ class HFLGBModel(ModelFT, LightGBMFInt):
|
||||
# Convert label into alpha
|
||||
df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0)
|
||||
df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0)
|
||||
mapping_fn = lambda x: 0 if x < 0 else 1
|
||||
|
||||
def mapping_fn(x):
|
||||
return 0 if x < 0 else 1
|
||||
|
||||
df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn)
|
||||
df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn)
|
||||
x_train, y_train = df_train["feature"], df_train["label_c"].values
|
||||
|
||||
@@ -292,7 +292,9 @@ class HIST(Model):
|
||||
pretrained_model.load_state_dict(torch.load(self.model_path))
|
||||
|
||||
model_dict = self.HIST_model.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_model.state_dict().items() if k in model_dict}
|
||||
pretrained_dict = {
|
||||
k: v for k, v in pretrained_model.state_dict().items() if k in model_dict # pylint: disable=E1135
|
||||
}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.HIST_model.load_state_dict(model_dict)
|
||||
self.logger.info("Loading pretrained model Done...")
|
||||
|
||||
@@ -53,7 +53,7 @@ class TabnetModel(Model):
|
||||
"""
|
||||
TabNet model for Qlib
|
||||
|
||||
Args:
|
||||
Args:
|
||||
ps: probability to generate the bernoulli mask
|
||||
"""
|
||||
# set hyper-parameters.
|
||||
|
||||
@@ -167,8 +167,8 @@ class TRAModel(Model):
|
||||
for param in self.tra.predictors.parameters():
|
||||
param.requires_grad_(False)
|
||||
|
||||
self.logger.info("# model params: %d" % sum([p.numel() for p in self.model.parameters() if p.requires_grad]))
|
||||
self.logger.info("# tra params: %d" % sum([p.numel() for p in self.tra.parameters() if p.requires_grad]))
|
||||
self.logger.info("# model params: %d" % sum(p.numel() for p in self.model.parameters() if p.requires_grad))
|
||||
self.logger.info("# tra params: %d" % sum(p.numel() for p in self.tra.parameters() if p.requires_grad))
|
||||
|
||||
self.optimizer = optim.Adam(list(self.model.parameters()) + list(self.tra.parameters()), lr=self.lr)
|
||||
|
||||
|
||||
@@ -32,7 +32,6 @@ from ..utils import (
|
||||
hash_args,
|
||||
normalize_cache_fields,
|
||||
code_to_fname,
|
||||
set_log_with_config,
|
||||
time_to_slc_point,
|
||||
read_period_data,
|
||||
get_period_list,
|
||||
@@ -603,11 +602,7 @@ class DatasetProvider(abc.ABC):
|
||||
"""
|
||||
# FIXME: Windows OS or MacOS using spawn: https://docs.python.org/3.8/library/multiprocessing.html?highlight=spawn#contexts-and-start-methods
|
||||
# NOTE: This place is compatible with windows, windows multi-process is spawn
|
||||
if not C.registered:
|
||||
C.set_conf_from_C(g_config)
|
||||
if C.logging_config:
|
||||
set_log_with_config(C.logging_config)
|
||||
C.register()
|
||||
C.register_from_C(g_config)
|
||||
|
||||
obj = dict()
|
||||
for field in column_names:
|
||||
|
||||
@@ -438,7 +438,7 @@ class TSDataSampler:
|
||||
|
||||
@property
|
||||
def empty(self):
|
||||
return self.__len__() == 0
|
||||
return len(self) == 0
|
||||
|
||||
def _get_indices(self, row: int, col: int) -> np.array:
|
||||
"""
|
||||
|
||||
@@ -24,7 +24,7 @@ class FileStorageMixin:
|
||||
|
||||
"""
|
||||
|
||||
# NOTE: provider_uri priority:
|
||||
# NOTE: provider_uri priority:
|
||||
# 1. self._provider_uri : if provider_uri is provided.
|
||||
# 2. provider_uri in qlib.config.C
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ Ensemble module can merge the objects in an Ensemble. For example, if there are
|
||||
from typing import Union
|
||||
import pandas as pd
|
||||
from qlib.utils import FLATTEN_TUPLE, flatten_dict
|
||||
from qlib.log import get_module_logger
|
||||
|
||||
|
||||
class Ensemble:
|
||||
@@ -79,6 +80,7 @@ class RollingEnsemble(Ensemble):
|
||||
"""
|
||||
|
||||
def __call__(self, ensemble_dict: dict) -> pd.DataFrame:
|
||||
get_module_logger("RollingEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
|
||||
artifact_list = list(ensemble_dict.values())
|
||||
artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min())
|
||||
artifact = pd.concat(artifact_list)
|
||||
@@ -121,6 +123,7 @@ class AverageEnsemble(Ensemble):
|
||||
"""
|
||||
# need to flatten the nested dict
|
||||
ensemble_dict = flatten_dict(ensemble_dict, sep=FLATTEN_TUPLE)
|
||||
get_module_logger("AverageEnsemble").info(f"keys in group: {list(ensemble_dict.keys())}")
|
||||
values = list(ensemble_dict.values())
|
||||
# NOTE: this may change the style underlying data!!!!
|
||||
# from pd.DataFrame to pd.Series
|
||||
|
||||
@@ -15,13 +15,22 @@ import socket
|
||||
from typing import Callable, List
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from qlib.config import C
|
||||
from qlib.data.dataset import Dataset
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
from qlib.log import get_module_logger
|
||||
from qlib.model.base import Model
|
||||
from qlib.utils import flatten_dict, init_instance_by_config, auto_filter_kwargs, fill_placeholder
|
||||
from qlib.utils import (
|
||||
auto_filter_kwargs,
|
||||
fill_placeholder,
|
||||
flatten_dict,
|
||||
init_instance_by_config,
|
||||
)
|
||||
from qlib.utils.paral import call_in_subproc
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.recorder import Recorder
|
||||
from qlib.workflow.task.manage import TaskManager, run_task
|
||||
from qlib.data.dataset.weight import Reweighter
|
||||
|
||||
|
||||
def _log_task_info(task_config: dict):
|
||||
@@ -210,17 +219,19 @@ class TrainerR(Trainer):
|
||||
STATUS_BEGIN = "begin_task_train"
|
||||
STATUS_END = "end_task_train"
|
||||
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train):
|
||||
def __init__(self, experiment_name: str = None, train_func: Callable = task_train, call_in_subproc: bool = False):
|
||||
"""
|
||||
Init TrainerR.
|
||||
|
||||
Args:
|
||||
experiment_name (str, optional): the default name of experiment.
|
||||
train_func (Callable, optional): default training method. Defaults to `task_train`.
|
||||
call_in_subproc (bool): call the process in subprocess to force memory release
|
||||
"""
|
||||
super().__init__()
|
||||
self.experiment_name = experiment_name
|
||||
self.train_func = train_func
|
||||
self._call_in_subproc = call_in_subproc
|
||||
|
||||
def train(self, tasks: list, train_func: Callable = None, experiment_name: str = None, **kwargs) -> List[Recorder]:
|
||||
"""
|
||||
@@ -245,6 +256,9 @@ class TrainerR(Trainer):
|
||||
experiment_name = self.experiment_name
|
||||
recs = []
|
||||
for task in tqdm(tasks, desc="train tasks"):
|
||||
if self._call_in_subproc:
|
||||
get_module_logger("TrainerR").info("running models in sub process (for forcing release memroy).")
|
||||
train_func = call_in_subproc(train_func, C)
|
||||
rec = train_func(task, experiment_name, **kwargs)
|
||||
rec.set_tags(**{self.STATUS_KEY: self.STATUS_BEGIN})
|
||||
recs.append(rec)
|
||||
|
||||
@@ -145,7 +145,7 @@ class DataQueue(Generic[T]):
|
||||
def __iter__(self):
|
||||
if not self._activated:
|
||||
raise ValueError(
|
||||
"Need to call activate() to launch a daemon worker " "to produce data into data queue before using it."
|
||||
"Need to call activate() to launch a daemon worker to produce data into data queue before using it."
|
||||
)
|
||||
return self._consumer()
|
||||
|
||||
|
||||
@@ -1,17 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
from __future__ import annotations
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from qlib.backtest.exchange import Exchange
|
||||
from qlib.backtest.position import BasePosition
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
from ..utils import init_instance_by_config
|
||||
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
|
||||
from ..backtest.decision import BaseTradeDecision
|
||||
|
||||
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
|
||||
|
||||
@@ -25,12 +28,13 @@ class BaseStrategy:
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
trade_exchange: Exchange = None,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision, optional
|
||||
the trade decision of outer strategy which this strategy relies, and it will be traded in [start_time, end_time], by default None
|
||||
the trade decision of outer strategy which this strategy relies, and it will be traded in
|
||||
[start_time, end_time], by default None
|
||||
- If the strategy is used to split trade decision, it will be used
|
||||
- If the strategy is used for portfolio management, it can be ignored
|
||||
level_infra : LevelInfrastructure, optional
|
||||
@@ -41,9 +45,10 @@ class BaseStrategy:
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info, used to deal order and generate report
|
||||
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
|
||||
- It allowes different trade_exchanges is used in different executions.
|
||||
- It allows different trade_exchanges is used in different executions.
|
||||
- For example:
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is recommended because it run faster.
|
||||
- In daily execution, both daily exchange and minutely are usable, but the daily exchange is
|
||||
recommended because it run faster.
|
||||
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
|
||||
"""
|
||||
|
||||
@@ -63,13 +68,13 @@ class BaseStrategy:
|
||||
"""get trade exchange in a prioritized order"""
|
||||
return getattr(self, "_trade_exchange", None) or self.common_infra.get("trade_exchange")
|
||||
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure):
|
||||
def reset_level_infra(self, level_infra: LevelInfrastructure) -> None:
|
||||
if not hasattr(self, "level_infra"):
|
||||
self.level_infra = level_infra
|
||||
else:
|
||||
self.level_infra.update(level_infra)
|
||||
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure):
|
||||
def reset_common_infra(self, common_infra: CommonInfrastructure) -> None:
|
||||
if not hasattr(self, "common_infra"):
|
||||
self.common_infra: CommonInfrastructure = common_infra
|
||||
else:
|
||||
@@ -79,9 +84,9 @@ class BaseStrategy:
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
**kwargs,
|
||||
):
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
**kwargs, # TODO: remove this?
|
||||
) -> None:
|
||||
"""
|
||||
- reset `level_infra`, used to reset trade calendar, .etc
|
||||
- reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc
|
||||
@@ -89,18 +94,20 @@ class BaseStrategy:
|
||||
|
||||
**NOTE**:
|
||||
split this function into `reset` and `_reset` will make following cases more convenient
|
||||
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset` called
|
||||
when initialization
|
||||
1. Users want to initialize his strategy by overriding `reset`, but they don't want to affect the `_reset`
|
||||
called when initialization
|
||||
"""
|
||||
self._reset(
|
||||
level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision, **kwargs
|
||||
level_infra=level_infra,
|
||||
common_infra=common_infra,
|
||||
outer_trade_decision=outer_trade_decision,
|
||||
)
|
||||
|
||||
def _reset(
|
||||
self,
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
outer_trade_decision=None,
|
||||
outer_trade_decision: BaseTradeDecision = None,
|
||||
):
|
||||
"""
|
||||
Please refer to the docs of `reset`
|
||||
@@ -114,7 +121,8 @@ class BaseStrategy:
|
||||
if outer_trade_decision is not None:
|
||||
self.outer_trade_decision = outer_trade_decision
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
@abstractmethod
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
"""Generate trade decision in each trading bar
|
||||
|
||||
Parameters
|
||||
@@ -125,9 +133,11 @@ class BaseStrategy:
|
||||
"""
|
||||
raise NotImplementedError("generate_trade_decision is not implemented!")
|
||||
|
||||
@staticmethod
|
||||
def update_trade_decision(
|
||||
self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager
|
||||
) -> Union[BaseTradeDecision, None]:
|
||||
trade_decision: BaseTradeDecision,
|
||||
trade_calendar: TradeCalendarManager,
|
||||
) -> Optional[BaseTradeDecision]:
|
||||
"""
|
||||
update trade decision in each step of inner execution, this method enable all order
|
||||
|
||||
@@ -145,7 +155,8 @@ class BaseStrategy:
|
||||
# default to return None, which indicates that the trade decision is not changed
|
||||
return None
|
||||
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision):
|
||||
# FIXME: do not define this method as an abstract one since it is never implemented
|
||||
def alter_outer_trade_decision(self, outer_trade_decision: BaseTradeDecision) -> BaseTradeDecision:
|
||||
"""
|
||||
A method for updating the outer_trade_decision.
|
||||
The outer strategy may change its decision during updating.
|
||||
@@ -154,6 +165,10 @@ class BaseStrategy:
|
||||
----------
|
||||
outer_trade_decision : BaseTradeDecision
|
||||
the decision updated by the outer strategy
|
||||
|
||||
Returns
|
||||
-------
|
||||
BaseTradeDecision
|
||||
"""
|
||||
# default to reset the decision directly
|
||||
# NOTE: normally, user should do something to the strategy due to the change of outer decision
|
||||
@@ -200,7 +215,7 @@ class RLStrategy(BaseStrategy):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -223,7 +238,7 @@ class RLIntStrategy(RLStrategy):
|
||||
level_infra: LevelInfrastructure = None,
|
||||
common_infra: CommonInfrastructure = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -242,7 +257,7 @@ class RLIntStrategy(RLStrategy):
|
||||
self.state_interpreter = init_instance_by_config(state_interpreter, accept_types=StateInterpreter)
|
||||
self.action_interpreter = init_instance_by_config(action_interpreter, accept_types=ActionInterpreter)
|
||||
|
||||
def generate_trade_decision(self, execute_result=None):
|
||||
def generate_trade_decision(self, execute_result: list = None) -> BaseTradeDecision:
|
||||
_interpret_state = self.state_interpreter.interpret(execute_result=execute_result)
|
||||
_action = self.policy.step(_interpret_state)
|
||||
_trade_decision = self.action_interpreter.interpret(action=_action)
|
||||
|
||||
@@ -16,7 +16,7 @@ from qlib.utils import exists_qlib_data
|
||||
|
||||
class GetData:
|
||||
DATASET_VERSION = "v2"
|
||||
REMOTE_URL = "http://fintech.msra.cn/stock_data/downloads"
|
||||
REMOTE_URL = "https://qlibpublic.blob.core.windows.net/data/default/stock_data"
|
||||
QLIB_DATA_NAME = "{dataset_name}_{region}_{interval}_{qlib_version}.zip"
|
||||
|
||||
def __init__(self, delete_zip_file=False):
|
||||
|
||||
@@ -376,7 +376,7 @@ get_cls_kwargs = get_callable_kwargs # NOTE: this is for compatibility for the
|
||||
|
||||
|
||||
def init_instance_by_config(
|
||||
config: Union[str, dict, object, Path],
|
||||
config: Union[str, dict, object, Path], # TODO: use a user-defined type to replace this Union.
|
||||
default_module=None,
|
||||
accept_types: Union[type, Tuple[type]] = (),
|
||||
try_kwargs: Dict = {},
|
||||
@@ -949,6 +949,10 @@ def auto_filter_kwargs(func: Callable, warning=True) -> Callable:
|
||||
|
||||
The decrated function will ignore and give warning when the parameter is not acceptable
|
||||
|
||||
For example, if you have a function `f` which may optionally consume the keywards `bar`.
|
||||
then you can call it by `auto_filter_kwargs(f)(bar=3)`, which will automatically filter out
|
||||
`bar` when f does not need bar
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
@@ -1063,4 +1067,5 @@ __all__ = [
|
||||
"unpack_archive_with_buffer",
|
||||
"get_tmp_file_with_buffer",
|
||||
"set_log_with_config",
|
||||
"init_instance_by_config",
|
||||
]
|
||||
|
||||
@@ -10,6 +10,9 @@ from joblib._parallel_backends import MultiprocessingBackend
|
||||
import pandas as pd
|
||||
|
||||
from queue import Queue
|
||||
import concurrent
|
||||
|
||||
from qlib.config import C, QlibConfig
|
||||
|
||||
|
||||
class ParallelExt(Parallel):
|
||||
@@ -273,3 +276,40 @@ def complex_parallel(paral: Parallel, complex_iter):
|
||||
dt.set_res(res)
|
||||
complex_iter = _recover_dt(complex_iter)
|
||||
return complex_iter
|
||||
|
||||
|
||||
class call_in_subproc:
|
||||
"""
|
||||
When we repeating run functions, it is hard to avoid memory leakage.
|
||||
So we run it in the subprocess to ensure it is OK.
|
||||
|
||||
NOTE: Because local object can't be pickled. So we can't implement it via closure.
|
||||
We have to implement it via callable Class
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable, qlib_config: QlibConfig = None):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
func : Callable
|
||||
the function to be wrapped
|
||||
|
||||
qlib_config : QlibConfig
|
||||
Qlib config for initialization in subprocess
|
||||
|
||||
Returns
|
||||
-------
|
||||
Callable
|
||||
"""
|
||||
self.func = func
|
||||
self.qlib_config = qlib_config
|
||||
|
||||
def _func_mod(self, *args, **kwargs):
|
||||
"""Modify the initial function by adding Qlib initialization"""
|
||||
if self.qlib_config is not None:
|
||||
C.register_from_C(self.qlib_config)
|
||||
return self.func(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor:
|
||||
return executor.submit(self._func_mod, *args, **kwargs).result()
|
||||
|
||||
@@ -131,7 +131,7 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
|
||||
.. note::
|
||||
|
||||
the start_time is not included in the hist_ref
|
||||
the start_time is not included in the `hist_ref`; So the `hist_ref` will be `step_len - 1` in most cases
|
||||
|
||||
loader_cls : type
|
||||
the class to load the model and dataset
|
||||
@@ -184,9 +184,9 @@ class DSBasedUpdater(RecordUpdater, metaclass=ABCMeta):
|
||||
dataset: DatasetH = self.record.load_object("dataset") if unprepared_dataset is None else unprepared_dataset
|
||||
# Special treatment of historical dependencies
|
||||
if isinstance(dataset, TSDatasetH):
|
||||
hist_ref = dataset.step_len
|
||||
hist_ref = dataset.step_len - 1
|
||||
else:
|
||||
hist_ref = 0
|
||||
hist_ref = 0 # if only the lastest data is used, then only current data will be used and no historical data will be used
|
||||
else:
|
||||
hist_ref = self.hist_ref
|
||||
|
||||
|
||||
@@ -169,7 +169,10 @@ class RecorderCollector(Collector):
|
||||
self.experiment = experiment
|
||||
self.artifacts_path = artifacts_path
|
||||
if rec_key_func is None:
|
||||
rec_key_func = lambda rec: rec.info["id"]
|
||||
|
||||
def rec_key_func(rec):
|
||||
return rec.info["id"]
|
||||
|
||||
if artifacts_key is None:
|
||||
artifacts_key = list(self.artifacts_path.keys())
|
||||
self.rec_key_func = rec_key_func
|
||||
|
||||
@@ -488,7 +488,7 @@ class DumpDataUpdate(DumpDataBase):
|
||||
except Exception:
|
||||
error_code[futures[_future]] = traceback.format_exc()
|
||||
p_bar.update()
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
logger.info(f"dump bin errors: {error_code}")
|
||||
|
||||
logger.info("end of features dump.\n")
|
||||
|
||||
|
||||
2
setup.py
2
setup.py
@@ -80,6 +80,8 @@ REQUIRED = [
|
||||
"filelock",
|
||||
"jinja2<3.1.0", # for passing the readthedocs workflow.
|
||||
"gym",
|
||||
# Installing the latest version of protobuf for python versions below 3.8 will cause unit tests to fail.
|
||||
"protobuf<=3.20.1;python_version<='3.8'",
|
||||
]
|
||||
|
||||
# Numpy include
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import unittest
|
||||
from qlib.backtest import backtest, decision
|
||||
from qlib.backtest import backtest
|
||||
from qlib.tests import TestAutoData
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
@@ -52,13 +52,12 @@ class FileStrTest(TestAutoData):
|
||||
factor = df["$factor"].item()
|
||||
price_unit = price / factor * 100
|
||||
dealt_num_for_1000 = (account_money // price_unit) * (100 / factor)
|
||||
print(price, factor, price_unit, dealt_num_for_1000)
|
||||
|
||||
# 2) generate orders
|
||||
orders = self._gen_orders(dealt_num_for_1000)
|
||||
print(orders)
|
||||
orders.to_csv(self.EXAMPLE_FILE)
|
||||
|
||||
orders = pd.read_csv(self.EXAMPLE_FILE, index_col=["datetime", "instrument"])
|
||||
print(orders)
|
||||
|
||||
# 3) run the strategy
|
||||
strategy_config = {
|
||||
@@ -101,7 +100,11 @@ class FileStrTest(TestAutoData):
|
||||
},
|
||||
},
|
||||
}
|
||||
report_dict, indicator_dict = backtest(executor=executor_config, strategy=strategy_config, **backtest_config)
|
||||
report_dict, indicator_dict = backtest(
|
||||
executor=executor_config,
|
||||
strategy=strategy_config,
|
||||
**backtest_config,
|
||||
)
|
||||
|
||||
# ffr valid
|
||||
ffr_dict = indicator_dict["1day"]["ffr"].to_dict()
|
||||
|
||||
Reference in New Issue
Block a user