mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
update env & strategy, add workflow
This commit is contained in:
135
examples/highfreq/backtest/workflow.py
Normal file
135
examples/highfreq/backtest/workflow.py
Normal file
@@ -0,0 +1,135 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import qlib
|
||||
import pandas as pd
|
||||
from qlib.config import REG_CN
|
||||
from qlib.contrib.model.gbdt import LGBModel
|
||||
from qlib.contrib.data.handler import Alpha158
|
||||
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
|
||||
from qlib.contrib.evaluate import (
|
||||
backtest as normal_backtest,
|
||||
risk_analysis,
|
||||
)
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
# use default data
|
||||
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
|
||||
if not exists_qlib_data(provider_uri):
|
||||
print(f"Qlib data is not found in {provider_uri}")
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
|
||||
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model.fit(dataset)
|
||||
|
||||
trade_start_time = "2017-01-01"
|
||||
trade_end_time = "2020-08-01"
|
||||
trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time)
|
||||
|
||||
backtest_config={
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.dl_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"trade_exchange": trade_exchange,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"env":{
|
||||
"class": "SplitEnv",
|
||||
"module_path": "qlib.backtest.env",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"sub_env": {
|
||||
"class": "SimulatorEnv",
|
||||
"module_path": "qlib.backtest.env",
|
||||
"kwargs": {
|
||||
"step_bar": "1min",
|
||||
"trade_exchange": trade_exchange,
|
||||
}
|
||||
},
|
||||
"sub_strategy": {
|
||||
"class": "SBBStrategyEMA",
|
||||
"module_path": "qlib.contrib.strategy.rule_strategy",
|
||||
"kwargs": {
|
||||
"step_bar": "1min",
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
|
||||
# backtest. If users want to use backtest based on their own prediction,
|
||||
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par.generate()
|
||||
@@ -8,95 +8,37 @@ from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .backtest import backtest as backtest_func, get_date_range
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
from ..utils import init_instance_by_config
|
||||
from ..log import get_module_logger
|
||||
from ..config import C
|
||||
|
||||
logger = get_module_logger("backtest caller")
|
||||
|
||||
|
||||
def get_strategy(
|
||||
strategy=None,
|
||||
topk=50,
|
||||
margin=0.5,
|
||||
n_drop=5,
|
||||
risk_degree=0.95,
|
||||
str_type="dropout",
|
||||
adjust_dates=None,
|
||||
):
|
||||
"""get_strategy
|
||||
|
||||
There will be 3 ways to return a stratgy. Please follow the code.
|
||||
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: Strategy
|
||||
an initialized strategy object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a strategy.
|
||||
if strategy is None:
|
||||
# 1) create strategy with param `strategy`
|
||||
str_cls_dict = {
|
||||
"amount": "TopkAmountStrategy",
|
||||
"weight": "TopkWeightStrategy",
|
||||
"dropout": "TopkDropoutStrategy",
|
||||
}
|
||||
logger.info("Create new strategy ")
|
||||
from .. import strategy as strategy_pool
|
||||
|
||||
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
|
||||
strategy = str_cls(
|
||||
topk=topk,
|
||||
buffer_margin=margin,
|
||||
n_drop=n_drop,
|
||||
risk_degree=risk_degree,
|
||||
adjust_dates=adjust_dates,
|
||||
)
|
||||
elif isinstance(strategy, (dict, str)):
|
||||
# 2) create strategy with init_instance_by_config
|
||||
logger.info("Create new strategy ")
|
||||
strategy = init_instance_by_config(strategy)
|
||||
|
||||
from ..strategy.strategy import BaseStrategy
|
||||
|
||||
# else: nothing happens. 3) Use the strategy directly
|
||||
if not isinstance(strategy, BaseStrategy):
|
||||
raise TypeError("Strategy not supported")
|
||||
return strategy
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"]):
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def get_exchange(
|
||||
pred,
|
||||
exchange=None,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
codes = "all",
|
||||
subscribe_fields=[],
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
@@ -104,7 +46,6 @@ def get_exchange(
|
||||
trade_unit=None,
|
||||
limit_threshold=None,
|
||||
deal_price=None,
|
||||
extract_codes=False,
|
||||
shift=1,
|
||||
):
|
||||
"""get_exchange
|
||||
@@ -128,9 +69,6 @@ def get_exchange(
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
NOTE: This will be faster with offline qlib.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -149,176 +87,52 @@ def get_exchange(
|
||||
# handle exception for deal_price
|
||||
if deal_price[0] != "$":
|
||||
deal_price = "$" + deal_price
|
||||
if extract_codes:
|
||||
codes = sorted(pred.index.get_level_values("instrument").unique())
|
||||
else:
|
||||
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
|
||||
|
||||
dates = sorted(pred.index.get_level_values("datetime").unique())
|
||||
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
|
||||
|
||||
exchange = Exchange(
|
||||
trade_dates=dates,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
codes=codes,
|
||||
deal_price=deal_price,
|
||||
subscribe_fields=subscribe_fields,
|
||||
limit_threshold=limit_threshold,
|
||||
open_cost=open_cost,
|
||||
close_cost=close_cost,
|
||||
min_cost=min_cost,
|
||||
trade_unit=trade_unit,
|
||||
min_cost=min_cost,
|
||||
)
|
||||
return exchange
|
||||
else:
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
trade_account = Account(init_cash=account)
|
||||
|
||||
def get_executor(
|
||||
executor=None,
|
||||
trade_exchange=None,
|
||||
verbose=True,
|
||||
):
|
||||
"""get_executor
|
||||
|
||||
There will be 3 ways to return a executor. Please follow the code.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
|
||||
executor : BaseExecutor
|
||||
executor used in backtest.
|
||||
trade_exchange : Exchange
|
||||
exchange used in executor
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
Returns
|
||||
-------
|
||||
:class: BaseExecutor
|
||||
an initialized BaseExecutor object
|
||||
"""
|
||||
|
||||
# There will be 3 ways to return a executor.
|
||||
if executor is None:
|
||||
# 1) create executor with param `executor`
|
||||
logger.info("Create new executor ")
|
||||
from ..online.executor import SimulatorExecutor
|
||||
|
||||
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
|
||||
elif isinstance(executor, (dict, str)):
|
||||
# 2) create executor with config
|
||||
logger.info("Create new executor ")
|
||||
executor = init_instance_by_config(executor)
|
||||
|
||||
from ..online.executor import BaseExecutor
|
||||
|
||||
# 3) Use the executor directly
|
||||
if not isinstance(executor, BaseExecutor):
|
||||
raise TypeError("Executor not supported")
|
||||
return executor
|
||||
|
||||
|
||||
# This is the API for compatibility for legacy code
|
||||
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
|
||||
"""This function will help you set a reasonable Exchange and provide default value for strategy
|
||||
Parameters
|
||||
----------
|
||||
|
||||
- **backtest workflow related or commmon arguments**
|
||||
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column.
|
||||
account : float
|
||||
init account value.
|
||||
shift : int
|
||||
whether to shift prediction by one day.
|
||||
benchmark : str
|
||||
benchmark code, default is SH000905 CSI 500.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
return_order : bool
|
||||
whether to return order list
|
||||
|
||||
- **strategy related arguments**
|
||||
|
||||
strategy : Strategy()
|
||||
strategy used in backtest.
|
||||
topk : int (Default value: 50)
|
||||
top-N stocks to buy.
|
||||
margin : int or float(Default value: 0.5)
|
||||
- if isinstance(margin, int):
|
||||
|
||||
sell_limit = margin
|
||||
|
||||
- else:
|
||||
|
||||
sell_limit = pred_in_a_day.count() * margin
|
||||
|
||||
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
|
||||
sell_limit should be no less than topk.
|
||||
n_drop : int
|
||||
number of stocks to be replaced in each trading date.
|
||||
risk_degree: float
|
||||
0-1, 0.95 for example, use 95% money to trade.
|
||||
str_type: 'amount', 'weight' or 'dropout'
|
||||
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
|
||||
|
||||
- **exchange related arguments**
|
||||
|
||||
exchange: Exchange()
|
||||
pass the exchange for speeding up.
|
||||
subscribe_fields: list
|
||||
subscribe fields.
|
||||
open_cost : float
|
||||
open transaction cost. The default value is 0.002(0.2%).
|
||||
close_cost : float
|
||||
close transaction cost. The default value is 0.002(0.2%).
|
||||
min_cost : float
|
||||
min transaction cost.
|
||||
trade_unit : int
|
||||
100 for China A.
|
||||
deal_price: str
|
||||
dealing price type: 'close', 'open', 'vwap'.
|
||||
limit_threshold : float
|
||||
limit move 0.1 (10%) for example, long and short with same limit.
|
||||
extract_codes: bool
|
||||
will we pass the codes extracted from the pred to the exchange.
|
||||
|
||||
.. note:: This will be faster with offline qlib.
|
||||
|
||||
- **executor related arguments**
|
||||
|
||||
executor : BaseExecutor()
|
||||
executor used in backtest.
|
||||
verbose : bool
|
||||
whether to print log.
|
||||
|
||||
"""
|
||||
# check strategy:
|
||||
spec = inspect.getfullargspec(get_strategy)
|
||||
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
strategy = get_strategy(**str_args)
|
||||
|
||||
# init exchange:
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(pred, **ex_args)
|
||||
|
||||
# init executor:
|
||||
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
|
||||
temp_env = trade_env
|
||||
while True:
|
||||
if hasattr(temp_env, "trade_exchange"):
|
||||
temp_env.reset(trade_exchange=trade_exchange)
|
||||
if hasattr(temp_env, "sub_env"):
|
||||
temp_env = temp_env.sub_env
|
||||
else:
|
||||
break
|
||||
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
trade_state = self.sub_env.get_first_state()
|
||||
|
||||
|
||||
while not trade_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(sub_order_list)
|
||||
|
||||
report_df = trade_account.report.generate_report_dataframe()
|
||||
positions = trade_account.get_positions()
|
||||
|
||||
# run backtest
|
||||
report_dict = backtest_func(
|
||||
pred=pred,
|
||||
strategy=strategy,
|
||||
executor=executor,
|
||||
trade_exchange=trade_exchange,
|
||||
shift=shift,
|
||||
verbose=verbose,
|
||||
account=account,
|
||||
benchmark=benchmark,
|
||||
return_order=return_order,
|
||||
)
|
||||
# for compatibility of the old API. return the dict positions
|
||||
report_dict = {"report_df": report_df, "positions": positions}
|
||||
|
||||
positions = report_dict.get("positions")
|
||||
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
|
||||
return report_dict
|
||||
return
|
||||
|
||||
@@ -129,8 +129,7 @@ class Account:
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
trade_time=trade_start_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
import warnings
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
from loguru import Logger
|
||||
@@ -22,70 +23,76 @@ class BaseEnv:
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
trade_account,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
track=False,
|
||||
trade_account=None,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, **kwargs)
|
||||
self.verbose = verbose
|
||||
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
|
||||
|
||||
def _reset_trade_date(self, start_time=None, end_time=None):
|
||||
def _reset_trade_calendar(self, start_time, end_time):
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
if not self.start_time or not self.end_time:
|
||||
raise ValueError("value of `start_time` or `end_time` is None")
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_dates)
|
||||
self.trade_index = 0
|
||||
if self.start_time and self.end_time:
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
def _get_position(self):
|
||||
return self.trade_account.current
|
||||
|
||||
def _get_trade_time(self):
|
||||
if 0 < self.trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif self.trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_date(start_time=start_time, end_time=end_time)
|
||||
self.track = kwargs.get("track", False)
|
||||
self.upper_action = kwargs.get("upper_action", None)
|
||||
self.trade_account = init_instance_by_config(kwargs.get("trade_account"))
|
||||
return self.trade_account
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
self.trade_account = trade_account
|
||||
|
||||
def execute(self, **kwargs):
|
||||
def get_first_state(self):
|
||||
init_state = {"current": self._get_position()}
|
||||
return init_state
|
||||
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
return
|
||||
(
|
||||
self.trade_account,
|
||||
{
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"trade_len": self.trade_len,
|
||||
"trade_index": self.trade_index - 1,
|
||||
}
|
||||
)
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
|
||||
|
||||
|
||||
class SplitEnv(BaseEnv):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
trade_account,
|
||||
sub_env,
|
||||
sub_strategy,
|
||||
track=False,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.sub_env = sub_env
|
||||
self.sub_strategy = sub_strategy
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track)
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose)
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
@@ -93,16 +100,18 @@ class SplitEnv(BaseEnv):
|
||||
#if self.track:
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
trade_start_time = self.trade_dates[self.trade_index]
|
||||
trade_end_time = self.trade_dates[self.trade_index + 1]
|
||||
self.sub_strategy.reset(trade_order_list=order_list)
|
||||
sub_account = self.sub_env.reset(trade_order_list=order_list, start_time=self.trade_dates[self.trade_index - 1], end_time=self.trade_dates[self.trade_index])
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
trade_state = self.sub_env.get_first_state()
|
||||
while not self.sub_env.finished():
|
||||
sub_order_list = self.sub_strategy.generate_order(sub_account)
|
||||
sub_account, sub_info = self.sub_env.execute(sub_order_list)
|
||||
_order_list = self.sub_strategy.generate_order(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
#episode_reward += sub_reward
|
||||
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
|
||||
return _account, _info
|
||||
_obs = {"current": self._get_position()}
|
||||
_info = {}
|
||||
return _obs, _info
|
||||
|
||||
|
||||
|
||||
@@ -111,16 +120,18 @@ class SimulatorEnv(BaseEnv):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
trade_account,
|
||||
trade_exchange,
|
||||
track=False,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
self.trade_exchange = trade_exchange
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose)
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose)
|
||||
|
||||
def reset(trade_exchange=None, **kwargs):
|
||||
super(SimulatorEnv, self).reset(**kwargs)
|
||||
self.trade_exchange=trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
"""
|
||||
@@ -128,9 +139,8 @@ class SimulatorEnv(BaseEnv):
|
||||
"""
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
|
||||
trade_start_time = self.trade_dates[self.trade_index]
|
||||
trade_end_time = self.trade_dates[self.trade_index + 1]
|
||||
super(SimulatorEnv, self).execute(**kwargs)
|
||||
ttrade_start_time, trade_end_time = self._get_trade_time()
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
@@ -165,5 +175,6 @@ class SimulatorEnv(BaseEnv):
|
||||
# do nothing
|
||||
pass
|
||||
self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
|
||||
return _account, {**_info, "trade_info", trade_info}
|
||||
_obs = {"current": self._get_position()}
|
||||
_info = {"trade_info": trade_info}
|
||||
return _obs, _info
|
||||
@@ -64,9 +64,9 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
# self.stock_count['code'] will be the days the stock has been hold
|
||||
# since last buy signal. This is designed for thresh
|
||||
self.stock_count = {}
|
||||
|
||||
self.hold_thresh = hold_thresh
|
||||
self.only_tradable = only_tradable
|
||||
|
||||
|
||||
def get_risk_degree(self, trade_index):
|
||||
"""get_risk_degree
|
||||
@@ -76,12 +76,10 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs):
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
if self.trade_index == 1:
|
||||
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
|
||||
else:
|
||||
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_last_trade_time()
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if self.only_tradable:
|
||||
# If The strategy only consider tradable stock when make decision
|
||||
@@ -114,7 +112,7 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
def filter_stock(l):
|
||||
return l
|
||||
|
||||
current_temp = copy.deepcopy(trade_account.current)
|
||||
current_temp = copy.deepcopy(current)
|
||||
# generate order list for this adjust date
|
||||
sell_order_list = []
|
||||
buy_order_list = []
|
||||
@@ -229,14 +227,15 @@ class TopkDropoutStrategy(DLStrategy):
|
||||
return sell_order_list + buy_order_list
|
||||
|
||||
class WeightStrategyBase(DLStrategy):
|
||||
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs):
|
||||
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
|
||||
self.trade_exchange = trade_exchange
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
"""
|
||||
@@ -253,7 +252,7 @@ class WeightStrategyBase(DLStrategy):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs):
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
@@ -269,11 +268,8 @@ class WeightStrategyBase(DLStrategy):
|
||||
# generate_order_list
|
||||
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
|
||||
super(WeightStrategyBase, self).generate_order_list()
|
||||
if self.trade_index == 1:
|
||||
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
|
||||
else:
|
||||
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
|
||||
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_pred_time()
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
current_temp = copy.deepcopy(trade_account.current)
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
|
||||
@@ -9,27 +9,18 @@ from ...backtest.order import Order
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
**kwargs,
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
self.trade_amount = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs):
|
||||
super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
trade_start_time = self.trade_dates[self.trade_index - 1]
|
||||
trade_end_time = self.trade_dates[self.trade_index]
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
_order = Order(
|
||||
@@ -43,7 +34,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
order_list.append(_order)
|
||||
return order_list
|
||||
|
||||
class SBBStrategy(RuleStrategy, TradingEnhancement):
|
||||
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
"""
|
||||
@@ -51,34 +42,22 @@ class SBBStrategy(RuleStrategy, TradingEnhancement):
|
||||
TREND_SHORT = 1
|
||||
TREND_LONG = 2
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
**kwargs,
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
def reset(self, trade_order_list=None, **kwargs):
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
self.trade_amount = {}
|
||||
self.trade_delay = {}
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
|
||||
self.trade_trend[(order.stock_id, order.direction)] = TREND_MID
|
||||
|
||||
def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs):
|
||||
super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
TradingEnhancement.reset(trade_order_list=trade_order_list)
|
||||
super(SBBStrategyBase, self).reset(**kwargs)
|
||||
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_order_list(self, trade_start_time, trade_end_time, **kwargs):
|
||||
super(TopkDropoutStrategy, self).generate_order_list()
|
||||
if self.trade_index == 1:
|
||||
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
|
||||
else:
|
||||
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
|
||||
def generate_order_list(self, **kwargs):
|
||||
super(SBBStrategyBase, self).generate_order_list()
|
||||
trade_start_time, trade_end_time = self._get_trade_time()
|
||||
pred_start_time, pred_end_time = self._get_last_trade_time()
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
if self.trade_index % 2 == 1:
|
||||
@@ -124,7 +103,7 @@ class SBBStrategy(RuleStrategy, TradingEnhancement):
|
||||
return order_list
|
||||
|
||||
|
||||
class SBBEMAStrategy(SBBStrategy):
|
||||
class SBBStrategyEMA(SBBStrategyBase):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
|
||||
"""
|
||||
@@ -137,17 +116,17 @@ class SBBEMAStrategy(SBBStrategy):
|
||||
freq="day",
|
||||
**kwargs,
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
if instruments is None:
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
self.instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
|
||||
self.freq = freq
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
|
||||
def _reset_trade_date(self, start_time=None, end_time=None):
|
||||
super(SignalStrategy, self)._reset_trade_date(start_time=start_time, end_time=end_time)
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
|
||||
fields = [("EMA...", "signal")]
|
||||
self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq)
|
||||
|
||||
|
||||
@@ -20,26 +20,49 @@ from ..data.data import D
|
||||
- label和freq和strategy的bar分离,这个如何决策呢
|
||||
"""
|
||||
class BaseStrategy:
|
||||
def __init__(self, step_bar, start_time, end_time, **kwargs):
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
|
||||
def _reset_trade_date(self, start_time=None, end_time=None):
|
||||
def _reset_trade_calendar(self, start_time, end_time, _calendar=None):
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
if not self.start_time or not self.end_time:
|
||||
raise ValueError("value of `start_time` or `end_time` is None")
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_dates)
|
||||
self.trade_index = 0
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_date(start_time=start_time, end_time=end_time)
|
||||
if self.start_time and self.end_time:
|
||||
if not _calendar:
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
else:
|
||||
self.trade_calendar = _calendar
|
||||
self.trade_len = len(self.trade_calendar)
|
||||
self.trade_index = 0
|
||||
else:
|
||||
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
|
||||
|
||||
def reset(self, start_time=None, end_time=None, _calendar=None):
|
||||
if start_time or end_time :
|
||||
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
|
||||
|
||||
def _get_trade_time(self):
|
||||
if 0 < self.trade_index < self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
|
||||
return trade_start_time, trade_end_time
|
||||
elif self.trade_index == self.trade_len - 1:
|
||||
trade_start_time = self.trade_calendar[self.trade_index - 1]
|
||||
trade_end_time = self.trade_calendar[self.trade_index]
|
||||
return trade_start_time, trade_end_time
|
||||
else:
|
||||
raise RuntimeError("trade_index out of range")
|
||||
|
||||
def _get_last_trade_time(self, shift=1):
|
||||
if self.trade_index - shift < 0:
|
||||
return None, None
|
||||
elif self.trade_index - shift == 0:
|
||||
return None, self.trade_index[self.trade_index - shift]
|
||||
else:
|
||||
return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift]
|
||||
def generate_order_list(self, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
@@ -48,7 +71,7 @@ class RuleStrategy(BaseStrategy):
|
||||
pass
|
||||
|
||||
class DLStrategy(BaseStrategy):
|
||||
def __init__(self, step_bar, start_time, end_time, model, dataset:DatasetH):
|
||||
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None):
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = self.model.predict(dataset)
|
||||
@@ -62,6 +85,5 @@ class DLStrategy(BaseStrategy):
|
||||
|
||||
class TradingEnhancement:
|
||||
def reset(self, trade_order_list):
|
||||
if trade_order_list:
|
||||
self.trade_order_list = trade_order_list
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
|
||||
Reference in New Issue
Block a user