1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

update env & strategy, add workflow

This commit is contained in:
bxdd
2021-04-22 22:28:01 +08:00
parent 8979d786a9
commit 39deb7d27f
12 changed files with 319 additions and 363 deletions

View File

@@ -0,0 +1,135 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import sys
from pathlib import Path
import qlib
import pandas as pd
from qlib.config import REG_CN
from qlib.contrib.model.gbdt import LGBModel
from qlib.contrib.data.handler import Alpha158
from qlib.contrib.strategy.strategy import TopkDropoutStrategy
from qlib.contrib.evaluate import (
backtest as normal_backtest,
risk_analysis,
)
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
from qlib.workflow import R
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
from qlib.tests.data import GetData
if __name__ == "__main__":
# use default data
provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir
if not exists_qlib_data(provider_uri):
print(f"Qlib data is not found in {provider_uri}")
GetData().qlib_data(target_dir=provider_uri, region=REG_CN)
qlib.init(provider_uri=provider_uri, region=REG_CN)
market = "csi300"
benchmark = "SH000300"
###################################
# train model
###################################
data_handler_config = {
"start_time": "2008-01-01",
"end_time": "2020-08-01",
"fit_start_time": "2008-01-01",
"fit_end_time": "2014-12-31",
"instruments": market,
}
task = {
"model": {
"class": "LGBModel",
"module_path": "qlib.contrib.model.gbdt",
"kwargs": {
"loss": "mse",
"colsample_bytree": 0.8879,
"learning_rate": 0.0421,
"subsample": 0.8789,
"lambda_l1": 205.6999,
"lambda_l2": 580.9768,
"max_depth": 8,
"num_leaves": 210,
"num_threads": 20,
},
},
"dataset": {
"class": "DatasetH",
"module_path": "qlib.data.dataset",
"kwargs": {
"handler": {
"class": "Alpha158",
"module_path": "qlib.contrib.data.handler",
"kwargs": data_handler_config,
},
"segments": {
"train": ("2008-01-01", "2014-12-31"),
"valid": ("2015-01-01", "2016-12-31"),
"test": ("2017-01-01", "2020-08-01"),
},
},
},
}
# model initialization
model = init_instance_by_config(task["model"])
dataset = init_instance_by_config(task["dataset"])
model.fit(dataset)
trade_start_time = "2017-01-01"
trade_end_time = "2020-08-01"
trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time)
backtest_config={
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.dl_strategy",
"kwargs": {
"step_bar": "day",
"model": model,
"dataset": dataset,
"trade_exchange": trade_exchange,
"topk": 50,
"n_drop": 5,
},
},
"env":{
"class": "SplitEnv",
"module_path": "qlib.backtest.env",
"kwargs": {
"step_bar": "day",
"sub_env": {
"class": "SimulatorEnv",
"module_path": "qlib.backtest.env",
"kwargs": {
"step_bar": "1min",
"trade_exchange": trade_exchange,
}
},
"sub_strategy": {
"class": "SBBStrategyEMA",
"module_path": "qlib.contrib.strategy.rule_strategy",
"kwargs": {
"step_bar": "1min",
}
}
}
}
}
# prediction
recorder = R.get_recorder()
sr = SignalRecord(model, dataset, recorder)
sr.generate()
# backtest. If users want to use backtest based on their own prediction,
# please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template.
par = PortAnaRecord(recorder, port_analysis_config)
par.generate()

View File

@@ -8,95 +8,37 @@ from .exchange import Exchange
from .report import Report
from .backtest import backtest as backtest_func, get_date_range
import copy
import numpy as np
import inspect
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
from ..utils import init_instance_by_config
from ..log import get_module_logger
from ..config import C
logger = get_module_logger("backtest caller")
def get_strategy(
strategy=None,
topk=50,
margin=0.5,
n_drop=5,
risk_degree=0.95,
str_type="dropout",
adjust_dates=None,
):
"""get_strategy
There will be 3 ways to return a stratgy. Please follow the code.
Parameters
----------
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
Returns
-------
:class: Strategy
an initialized strategy object
"""
# There will be 3 ways to return a strategy.
if strategy is None:
# 1) create strategy with param `strategy`
str_cls_dict = {
"amount": "TopkAmountStrategy",
"weight": "TopkWeightStrategy",
"dropout": "TopkDropoutStrategy",
}
logger.info("Create new strategy ")
from .. import strategy as strategy_pool
str_cls = getattr(strategy_pool, str_cls_dict.get(str_type))
strategy = str_cls(
topk=topk,
buffer_margin=margin,
n_drop=n_drop,
risk_degree=risk_degree,
adjust_dates=adjust_dates,
)
elif isinstance(strategy, (dict, str)):
# 2) create strategy with init_instance_by_config
logger.info("Create new strategy ")
strategy = init_instance_by_config(strategy)
from ..strategy.strategy import BaseStrategy
# else: nothing happens. 3) Use the strategy directly
if not isinstance(strategy, BaseStrategy):
raise TypeError("Strategy not supported")
return strategy
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"]):
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def get_exchange(
pred,
exchange=None,
start_time=None,
end_time=None,
codes = "all",
subscribe_fields=[],
open_cost=0.0015,
close_cost=0.0025,
@@ -104,7 +46,6 @@ def get_exchange(
trade_unit=None,
limit_threshold=None,
deal_price=None,
extract_codes=False,
shift=1,
):
"""get_exchange
@@ -128,9 +69,6 @@ def get_exchange(
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
Returns
-------
@@ -149,176 +87,52 @@ def get_exchange(
# handle exception for deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
if extract_codes:
codes = sorted(pred.index.get_level_values("instrument").unique())
else:
codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks
dates = sorted(pred.index.get_level_values("datetime").unique())
dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift))
exchange = Exchange(
trade_dates=dates,
start_time=start_time,
end_time=end_time,
codes=codes,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
min_cost=min_cost,
)
return exchange
else:
return init_instance_by_config(exchange, accept_types=Exchange)
def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
trade_account = Account(init_cash=account)
def get_executor(
executor=None,
trade_exchange=None,
verbose=True,
):
"""get_executor
There will be 3 ways to return a executor. Please follow the code.
Parameters
----------
executor : BaseExecutor
executor used in backtest.
trade_exchange : Exchange
exchange used in executor
verbose : bool
whether to print log.
Returns
-------
:class: BaseExecutor
an initialized BaseExecutor object
"""
# There will be 3 ways to return a executor.
if executor is None:
# 1) create executor with param `executor`
logger.info("Create new executor ")
from ..online.executor import SimulatorExecutor
executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose)
elif isinstance(executor, (dict, str)):
# 2) create executor with config
logger.info("Create new executor ")
executor = init_instance_by_config(executor)
from ..online.executor import BaseExecutor
# 3) Use the executor directly
if not isinstance(executor, BaseExecutor):
raise TypeError("Executor not supported")
return executor
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
Parameters
----------
- **backtest workflow related or commmon arguments**
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column.
account : float
init account value.
shift : int
whether to shift prediction by one day.
benchmark : str
benchmark code, default is SH000905 CSI 500.
verbose : bool
whether to print log.
return_order : bool
whether to return order list
- **strategy related arguments**
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
"""
# check strategy:
spec = inspect.getfullargspec(get_strategy)
str_args = {k: v for k, v in kwargs.items() if k in spec.args}
strategy = get_strategy(**str_args)
# init exchange:
spec = inspect.getfullargspec(get_exchange)
ex_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(pred, **ex_args)
# init executor:
executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose)
temp_env = trade_env
while True:
if hasattr(temp_env, "trade_exchange"):
temp_env.reset(trade_exchange=trade_exchange)
if hasattr(temp_env, "sub_env"):
temp_env = temp_env.sub_env
else:
break
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_state = self.sub_env.get_first_state()
while not trade_env.finished():
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(sub_order_list)
report_df = trade_account.report.generate_report_dataframe()
positions = trade_account.get_positions()
# run backtest
report_dict = backtest_func(
pred=pred,
strategy=strategy,
executor=executor,
trade_exchange=trade_exchange,
shift=shift,
verbose=verbose,
account=account,
benchmark=benchmark,
return_order=return_order,
)
# for compatibility of the old API. return the dict positions
report_dict = {"report_df": report_df, "positions": positions}
positions = report_dict.get("positions")
report_dict.update({"positions": {k: p.position for k, p in positions.items()}})
return report_dict
return

View File

@@ -129,8 +129,7 @@ class Account:
# judge whether the the trading is begin.
# and don't add init account state into report, due to we don't have excess return in those days.
self.report.update_report_record(
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
trade_time=trade_start_time,
account_value=now_account_value,
cash=self.current.position["cash"],
return_rate=(self.earning + self.ct) / last_account_value,

View File

@@ -3,6 +3,7 @@
import re
import json
import copy
import warnings
import pathlib
import pandas as pd
from loguru import Logger
@@ -22,70 +23,76 @@ class BaseEnv:
def __init__(
self,
step_bar,
trade_account,
start_time=None,
end_time=None,
track=False,
trade_account=None,
verbose=False,
**kwargs
):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, **kwargs)
self.verbose = verbose
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs)
def _reset_trade_date(self, start_time=None, end_time=None):
def _reset_trade_calendar(self, start_time, end_time):
if start_time:
self.start_time = start_time
if end_time:
self.end_time = end_time
if not self.start_time or not self.end_time:
raise ValueError("value of `start_time` or `end_time` is None")
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
self.trade_len = len(self.trade_dates)
self.trade_index = 0
if self.start_time and self.end_time:
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
self.trade_len = len(self.trade_calendar)
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def reset(self, start_time=None, end_time=None, **kwargs):
def _get_position(self):
return self.trade_account.current
def _get_trade_time(self):
if 0 < self.trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
return trade_start_time, trade_end_time
elif self.trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs):
if start_time or end_time:
self._reset_trade_date(start_time=start_time, end_time=end_time)
self.track = kwargs.get("track", False)
self.upper_action = kwargs.get("upper_action", None)
self.trade_account = init_instance_by_config(kwargs.get("trade_account"))
return self.trade_account
self._reset_trade_calendar(start_time=start_time, end_time=end_time)
self.trade_account = trade_account
def execute(self, **kwargs):
def get_first_state(self):
init_state = {"current": self._get_position()}
return init_state
def execute(self, order_list, **kwargs):
self.trade_index = self.trade_index + 1
return
(
self.trade_account,
{
"start_time": self.start_time,
"end_time": self.end_time,
"trade_len": self.trade_len,
"trade_index": self.trade_index - 1,
}
)
def finished(self):
return self.trade_index >= self.trade_len - 1
class SplitEnv(BaseEnv):
def __init__(
self,
step_bar,
start_time,
end_time,
trade_account,
sub_env,
sub_strategy,
track=False,
start_time=None,
end_time=None,
trade_account=None,
verbose=False,
**kwargs
):
self.sub_env = sub_env
self.sub_strategy = sub_strategy
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track)
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose)
def execute(self, order_list, **kwargs):
if self.finished():
@@ -93,16 +100,18 @@ class SplitEnv(BaseEnv):
#if self.track:
# yield action
#episode_reward = 0
trade_start_time = self.trade_dates[self.trade_index]
trade_end_time = self.trade_dates[self.trade_index + 1]
self.sub_strategy.reset(trade_order_list=order_list)
sub_account = self.sub_env.reset(trade_order_list=order_list, start_time=self.trade_dates[self.trade_index - 1], end_time=self.trade_dates[self.trade_index])
super(SimulatorEnv, self).execute(**kwargs)
trade_start_time, trade_end_time = self._get_trade_time()
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
trade_state = self.sub_env.get_first_state()
while not self.sub_env.finished():
sub_order_list = self.sub_strategy.generate_order(sub_account)
sub_account, sub_info = self.sub_env.execute(sub_order_list)
_order_list = self.sub_strategy.generate_order(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
#episode_reward += sub_reward
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
return _account, _info
_obs = {"current": self._get_position()}
_info = {}
return _obs, _info
@@ -111,16 +120,18 @@ class SimulatorEnv(BaseEnv):
def __init__(
self,
step_bar,
start_time,
end_time,
trade_account,
trade_exchange,
track=False,
start_time=None,
end_time=None,
trade_account=None,
trade_exchange=None,
verbose=False,
**kwargs
**kwargs,
):
self.trade_exchange = trade_exchange
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose)
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose)
def reset(trade_exchange=None, **kwargs):
super(SimulatorEnv, self).reset(**kwargs)
self.trade_exchange=trade_exchange
def execute(self, order_list, **kwargs):
"""
@@ -128,9 +139,8 @@ class SimulatorEnv(BaseEnv):
"""
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
trade_start_time = self.trade_dates[self.trade_index]
trade_end_time = self.trade_dates[self.trade_index + 1]
super(SimulatorEnv, self).execute(**kwargs)
ttrade_start_time, trade_end_time = self._get_trade_time()
trade_info = []
for order in order_list:
if self.trade_exchange.check_order(order) is True:
@@ -165,5 +175,6 @@ class SimulatorEnv(BaseEnv):
# do nothing
pass
self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
return _account, {**_info, "trade_info", trade_info}
_obs = {"current": self._get_position()}
_info = {"trade_info": trade_info}
return _obs, _info

View File

@@ -64,9 +64,9 @@ class TopkDropoutStrategy(DLStrategy):
# self.stock_count['code'] will be the days the stock has been hold
# since last buy signal. This is designed for thresh
self.stock_count = {}
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
def get_risk_degree(self, trade_index):
"""get_risk_degree
@@ -76,12 +76,10 @@ class TopkDropoutStrategy(DLStrategy):
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs):
def generate_order_list(self, current, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
if self.trade_index == 1:
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
else:
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
trade_start_time, trade_end_time = self._get_trade_time()
pred_start_time, pred_end_time = self._get_last_trade_time()
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if self.only_tradable:
# If The strategy only consider tradable stock when make decision
@@ -114,7 +112,7 @@ class TopkDropoutStrategy(DLStrategy):
def filter_stock(l):
return l
current_temp = copy.deepcopy(trade_account.current)
current_temp = copy.deepcopy(current)
# generate order list for this adjust date
sell_order_list = []
buy_order_list = []
@@ -229,14 +227,15 @@ class TopkDropoutStrategy(DLStrategy):
return sell_order_list + buy_order_list
class WeightStrategyBase(DLStrategy):
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, **kwargs):
super().__init__(**kwargs)
def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs):
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
self.trade_exchange = trade_exchange
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
"""
@@ -253,7 +252,7 @@ class WeightStrategyBase(DLStrategy):
"""
raise NotImplementedError()
def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs):
def generate_order_list(self, current, **kwargs):
"""
Parameters
-----------
@@ -269,11 +268,8 @@ class WeightStrategyBase(DLStrategy):
# generate_order_list
# generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list
super(WeightStrategyBase, self).generate_order_list()
if self.trade_index == 1:
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
else:
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
trade_start_time, trade_end_time = self._get_trade_time()
pred_start_time, pred_end_time = self._get_pred_time()
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
current_temp = copy.deepcopy(trade_account.current)
target_weight_position = self.generate_target_weight_position(

View File

@@ -9,27 +9,18 @@ from ...backtest.order import Order
class TWAPStrategy(RuleStrategy, TradingEnhancement):
def __init__(
self,
step_bar,
start_time,
end_time,
**kwargs,
):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def reset(self, trade_order_list=None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
TradingEnhancement.reset(trade_order_list=trade_order_list)
self.trade_amount = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs):
super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs)
TradingEnhancement.reset(trade_order_list=trade_order_list)
def generate_order_list(self, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
trade_start_time = self.trade_dates[self.trade_index - 1]
trade_end_time = self.trade_dates[self.trade_index]
trade_start_time, trade_end_time = self._get_trade_time()
order_list = []
for order in self.trade_order_list:
_order = Order(
@@ -43,7 +34,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
order_list.append(_order)
return order_list
class SBBStrategy(RuleStrategy, TradingEnhancement):
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
"""
@@ -51,34 +42,22 @@ class SBBStrategy(RuleStrategy, TradingEnhancement):
TREND_SHORT = 1
TREND_LONG = 2
def __init__(
self,
step_bar,
start_time,
end_time,
**kwargs,
):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def reset(self, trade_order_list=None, **kwargs):
TradingEnhancement.reset(trade_order_list=trade_order_list)
self.trade_amount = {}
self.trade_delay = {}
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len
self.trade_trend[(order.stock_id, order.direction)] = TREND_MID
def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs):
super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs)
TradingEnhancement.reset(trade_order_list=trade_order_list)
super(SBBStrategyBase, self).reset(**kwargs)
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_order_list(self, trade_start_time, trade_end_time, **kwargs):
super(TopkDropoutStrategy, self).generate_order_list()
if self.trade_index == 1:
pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1)
else:
pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1)
def generate_order_list(self, **kwargs):
super(SBBStrategyBase, self).generate_order_list()
trade_start_time, trade_end_time = self._get_trade_time()
pred_start_time, pred_end_time = self._get_last_trade_time()
order_list = []
for order in self.trade_order_list:
if self.trade_index % 2 == 1:
@@ -124,7 +103,7 @@ class SBBStrategy(RuleStrategy, TradingEnhancement):
return order_list
class SBBEMAStrategy(SBBStrategy):
class SBBStrategyEMA(SBBStrategyBase):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA).
"""
@@ -137,17 +116,17 @@ class SBBEMAStrategy(SBBStrategy):
freq="day",
**kwargs,
):
self.step_bar = step_bar
super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, **kwargs)
if instruments is None:
warnings.warn("`instruments` is not set, will load all stocks")
self.instruments = "all"
if isinstance(instruments, str):
self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
self.freq = freq
self.reset(start_time=start_time, end_time=end_time)
def _reset_trade_date(self, start_time=None, end_time=None):
super(SignalStrategy, self)._reset_trade_date(start_time=start_time, end_time=end_time)
def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar)
fields = [("EMA...", "signal")]
self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq)

View File

@@ -20,26 +20,49 @@ from ..data.data import D
- label和freq和strategy的bar分离这个如何决策呢
"""
class BaseStrategy:
def __init__(self, step_bar, start_time, end_time, **kwargs):
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time, **kwargs)
def _reset_trade_date(self, start_time=None, end_time=None):
def _reset_trade_calendar(self, start_time, end_time, _calendar=None):
if start_time:
self.start_time = start_time
if end_time:
self.end_time = end_time
if not self.start_time or not self.end_time:
raise ValueError("value of `start_time` or `end_time` is None")
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
self.trade_len = len(self.trade_dates)
self.trade_index = 0
def reset(self, start_time=None, end_time=None, **kwargs):
if start_time or end_time:
self._reset_trade_date(start_time=start_time, end_time=end_time)
if self.start_time and self.end_time:
if not _calendar:
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time))
else:
self.trade_calendar = _calendar
self.trade_len = len(self.trade_calendar)
self.trade_index = 0
else:
raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.")
def reset(self, start_time=None, end_time=None, _calendar=None):
if start_time or end_time :
self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar)
def _get_trade_time(self):
if 0 < self.trade_index < self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1)
return trade_start_time, trade_end_time
elif self.trade_index == self.trade_len - 1:
trade_start_time = self.trade_calendar[self.trade_index - 1]
trade_end_time = self.trade_calendar[self.trade_index]
return trade_start_time, trade_end_time
else:
raise RuntimeError("trade_index out of range")
def _get_last_trade_time(self, shift=1):
if self.trade_index - shift < 0:
return None, None
elif self.trade_index - shift == 0:
return None, self.trade_index[self.trade_index - shift]
else:
return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift]
def generate_order_list(self, **kwargs):
self.trade_index = self.trade_index + 1
@@ -48,7 +71,7 @@ class RuleStrategy(BaseStrategy):
pass
class DLStrategy(BaseStrategy):
def __init__(self, step_bar, start_time, end_time, model, dataset:DatasetH):
def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None):
self.model = model
self.dataset = dataset
self.pred_scores = self.model.predict(dataset)
@@ -62,6 +85,5 @@ class DLStrategy(BaseStrategy):
class TradingEnhancement:
def reset(self, trade_order_list):
if trade_order_list:
self.trade_order_list = trade_order_list
self.trade_order_list = trade_order_list