From 39deb7d27fd49b5b7074282d50c1c7de0ee0e0bf Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 22 Apr 2021 22:28:01 +0800 Subject: [PATCH] update env & strategy, add workflow --- examples/highfreq/backtest/workflow.py | 135 ++++++++ examples/highfreq/{ => data}/README.md | 0 .../highfreq/{ => data}/highfreq_handler.py | 0 examples/highfreq/{ => data}/highfreq_ops.py | 0 .../highfreq/{ => data}/highfreq_processor.py | 0 examples/highfreq/{ => data}/workflow.py | 0 qlib/backtest/__init__.py | 288 ++++-------------- qlib/backtest/account.py | 3 +- qlib/backtest/env.py | 119 ++++---- qlib/contrib/strategy/dl_strategy.py | 26 +- qlib/contrib/strategy/rule_strategy.py | 59 ++-- qlib/strategy/base.py | 52 +++- 12 files changed, 319 insertions(+), 363 deletions(-) create mode 100644 examples/highfreq/backtest/workflow.py rename examples/highfreq/{ => data}/README.md (100%) rename examples/highfreq/{ => data}/highfreq_handler.py (100%) rename examples/highfreq/{ => data}/highfreq_ops.py (100%) rename examples/highfreq/{ => data}/highfreq_processor.py (100%) rename examples/highfreq/{ => data}/workflow.py (100%) diff --git a/examples/highfreq/backtest/workflow.py b/examples/highfreq/backtest/workflow.py new file mode 100644 index 000000000..cddc78b92 --- /dev/null +++ b/examples/highfreq/backtest/workflow.py @@ -0,0 +1,135 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from pathlib import Path + +import qlib +import pandas as pd +from qlib.config import REG_CN +from qlib.contrib.model.gbdt import LGBModel +from qlib.contrib.data.handler import Alpha158 +from qlib.contrib.strategy.strategy import TopkDropoutStrategy +from qlib.contrib.evaluate import ( + backtest as normal_backtest, + risk_analysis, +) +from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.workflow import R +from qlib.workflow.record_temp import SignalRecord, PortAnaRecord +from qlib.tests.data import GetData + +if __name__ == "__main__": + + # use default data + provider_uri = "~/.qlib/qlib_data/cn_data" # target_dir + if not exists_qlib_data(provider_uri): + print(f"Qlib data is not found in {provider_uri}") + GetData().qlib_data(target_dir=provider_uri, region=REG_CN) + + qlib.init(provider_uri=provider_uri, region=REG_CN) + + market = "csi300" + benchmark = "SH000300" + + ################################### + # train model + ################################### + + data_handler_config = { + "start_time": "2008-01-01", + "end_time": "2020-08-01", + "fit_start_time": "2008-01-01", + "fit_end_time": "2014-12-31", + "instruments": market, + } + + task = { + "model": { + "class": "LGBModel", + "module_path": "qlib.contrib.model.gbdt", + "kwargs": { + "loss": "mse", + "colsample_bytree": 0.8879, + "learning_rate": 0.0421, + "subsample": 0.8789, + "lambda_l1": 205.6999, + "lambda_l2": 580.9768, + "max_depth": 8, + "num_leaves": 210, + "num_threads": 20, + }, + }, + "dataset": { + "class": "DatasetH", + "module_path": "qlib.data.dataset", + "kwargs": { + "handler": { + "class": "Alpha158", + "module_path": "qlib.contrib.data.handler", + "kwargs": data_handler_config, + }, + "segments": { + "train": ("2008-01-01", "2014-12-31"), + "valid": ("2015-01-01", "2016-12-31"), + "test": ("2017-01-01", "2020-08-01"), + }, + }, + }, + } + # model initialization + model = init_instance_by_config(task["model"]) + dataset = init_instance_by_config(task["dataset"]) + model.fit(dataset) + + trade_start_time = "2017-01-01" + trade_end_time = "2020-08-01" + trade_exchange = get_exchange(start_time=trade_start_time, end_time=trade_end_time) + + backtest_config={ + "strategy": { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.dl_strategy", + "kwargs": { + "step_bar": "day", + "model": model, + "dataset": dataset, + "trade_exchange": trade_exchange, + "topk": 50, + "n_drop": 5, + }, + }, + "env":{ + "class": "SplitEnv", + "module_path": "qlib.backtest.env", + "kwargs": { + "step_bar": "day", + "sub_env": { + "class": "SimulatorEnv", + "module_path": "qlib.backtest.env", + "kwargs": { + "step_bar": "1min", + "trade_exchange": trade_exchange, + } + }, + "sub_strategy": { + "class": "SBBStrategyEMA", + "module_path": "qlib.contrib.strategy.rule_strategy", + "kwargs": { + "step_bar": "1min", + } + } + } + } + } + + + # prediction + recorder = R.get_recorder() + sr = SignalRecord(model, dataset, recorder) + sr.generate() + + # backtest. If users want to use backtest based on their own prediction, + # please refer to https://qlib.readthedocs.io/en/latest/component/recorder.html#record-template. + par = PortAnaRecord(recorder, port_analysis_config) + par.generate() diff --git a/examples/highfreq/README.md b/examples/highfreq/data/README.md similarity index 100% rename from examples/highfreq/README.md rename to examples/highfreq/data/README.md diff --git a/examples/highfreq/highfreq_handler.py b/examples/highfreq/data/highfreq_handler.py similarity index 100% rename from examples/highfreq/highfreq_handler.py rename to examples/highfreq/data/highfreq_handler.py diff --git a/examples/highfreq/highfreq_ops.py b/examples/highfreq/data/highfreq_ops.py similarity index 100% rename from examples/highfreq/highfreq_ops.py rename to examples/highfreq/data/highfreq_ops.py diff --git a/examples/highfreq/highfreq_processor.py b/examples/highfreq/data/highfreq_processor.py similarity index 100% rename from examples/highfreq/highfreq_processor.py rename to examples/highfreq/data/highfreq_processor.py diff --git a/examples/highfreq/workflow.py b/examples/highfreq/data/workflow.py similarity index 100% rename from examples/highfreq/workflow.py rename to examples/highfreq/data/workflow.py diff --git a/qlib/backtest/__init__.py b/qlib/backtest/__init__.py index aa24ffb0c..0afe03ea4 100644 --- a/qlib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -8,95 +8,37 @@ from .exchange import Exchange from .report import Report from .backtest import backtest as backtest_func, get_date_range +import copy import numpy as np import inspect -from ...utils import init_instance_by_config -from ...log import get_module_logger -from ...config import C +from ..utils import init_instance_by_config +from ..log import get_module_logger +from ..config import C logger = get_module_logger("backtest caller") -def get_strategy( - strategy=None, - topk=50, - margin=0.5, - n_drop=5, - risk_degree=0.95, - str_type="dropout", - adjust_dates=None, -): - """get_strategy - - There will be 3 ways to return a stratgy. Please follow the code. - - - Parameters - ---------- - - strategy : Strategy() - strategy used in backtest. - topk : int (Default value: 50) - top-N stocks to buy. - margin : int or float(Default value: 0.5) - - if isinstance(margin, int): - - sell_limit = margin - - - else: - - sell_limit = pred_in_a_day.count() * margin - - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). - sell_limit should be no less than topk. - n_drop : int - number of stocks to be replaced in each trading date. - risk_degree: float - 0-1, 0.95 for example, use 95% money to trade. - str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. - - Returns - ------- - :class: Strategy - an initialized strategy object - """ - - # There will be 3 ways to return a strategy. - if strategy is None: - # 1) create strategy with param `strategy` - str_cls_dict = { - "amount": "TopkAmountStrategy", - "weight": "TopkWeightStrategy", - "dropout": "TopkDropoutStrategy", - } - logger.info("Create new strategy ") - from .. import strategy as strategy_pool - - str_cls = getattr(strategy_pool, str_cls_dict.get(str_type)) - strategy = str_cls( - topk=topk, - buffer_margin=margin, - n_drop=n_drop, - risk_degree=risk_degree, - adjust_dates=adjust_dates, - ) - elif isinstance(strategy, (dict, str)): - # 2) create strategy with init_instance_by_config - logger.info("Create new strategy ") - strategy = init_instance_by_config(strategy) - - from ..strategy.strategy import BaseStrategy - - # else: nothing happens. 3) Use the strategy directly - if not isinstance(strategy, BaseStrategy): - raise TypeError("Strategy not supported") - return strategy +def init_env_instance_by_config(env): + if isinstance(env, dict): + env_config = copy.copy(env) + if "kwargs" in env_config: + env_kwargs = copy.copy(env_config["kwargs"]): + if "sub_env" in env_kwargs: + env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) + if "sub_strategy" in env_kwargs: + env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) + env_config["kwargs"] = env_kwargs + return init_instance_by_config(env_config) + else: + return env def get_exchange( pred, exchange=None, + start_time=None, + end_time=None, + codes = "all", subscribe_fields=[], open_cost=0.0015, close_cost=0.0025, @@ -104,7 +46,6 @@ def get_exchange( trade_unit=None, limit_threshold=None, deal_price=None, - extract_codes=False, shift=1, ): """get_exchange @@ -128,9 +69,6 @@ def get_exchange( dealing price type: 'close', 'open', 'vwap'. limit_threshold : float limit move 0.1 (10%) for example, long and short with same limit. - extract_codes: bool - will we pass the codes extracted from the pred to the exchange. - NOTE: This will be faster with offline qlib. Returns ------- @@ -149,176 +87,52 @@ def get_exchange( # handle exception for deal_price if deal_price[0] != "$": deal_price = "$" + deal_price - if extract_codes: - codes = sorted(pred.index.get_level_values("instrument").unique()) - else: - codes = "all" # TODO: We must ensure that 'all.txt' includes all the stocks - - dates = sorted(pred.index.get_level_values("datetime").unique()) - dates = np.append(dates, get_date_range(dates[-1], left_shift=1, right_shift=shift)) exchange = Exchange( - trade_dates=dates, + start_time=start_time, + end_time=end_time, codes=codes, deal_price=deal_price, subscribe_fields=subscribe_fields, limit_threshold=limit_threshold, open_cost=open_cost, close_cost=close_cost, - min_cost=min_cost, trade_unit=trade_unit, + min_cost=min_cost, ) - return exchange + else: + return init_instance_by_config(exchange, accept_types=Exchange) +def backtest(start_time, end_time, strategy, env, account=1e9, benchmark, **kwargs): + trade_strategy = init_instance_by_config(strategy) + trade_env = init_env_instance_by_config(env) + trade_account = Account(init_cash=account) -def get_executor( - executor=None, - trade_exchange=None, - verbose=True, -): - """get_executor - - There will be 3 ways to return a executor. Please follow the code. - - Parameters - ---------- - - executor : BaseExecutor - executor used in backtest. - trade_exchange : Exchange - exchange used in executor - verbose : bool - whether to print log. - - Returns - ------- - :class: BaseExecutor - an initialized BaseExecutor object - """ - - # There will be 3 ways to return a executor. - if executor is None: - # 1) create executor with param `executor` - logger.info("Create new executor ") - from ..online.executor import SimulatorExecutor - - executor = SimulatorExecutor(trade_exchange=trade_exchange, verbose=verbose) - elif isinstance(executor, (dict, str)): - # 2) create executor with config - logger.info("Create new executor ") - executor = init_instance_by_config(executor) - - from ..online.executor import BaseExecutor - - # 3) Use the executor directly - if not isinstance(executor, BaseExecutor): - raise TypeError("Executor not supported") - return executor - - -# This is the API for compatibility for legacy code -def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, return_order=False, **kwargs): - """This function will help you set a reasonable Exchange and provide default value for strategy - Parameters - ---------- - - - **backtest workflow related or commmon arguments** - - pred : pandas.DataFrame - predict should has index and one `score` column. - account : float - init account value. - shift : int - whether to shift prediction by one day. - benchmark : str - benchmark code, default is SH000905 CSI 500. - verbose : bool - whether to print log. - return_order : bool - whether to return order list - - - **strategy related arguments** - - strategy : Strategy() - strategy used in backtest. - topk : int (Default value: 50) - top-N stocks to buy. - margin : int or float(Default value: 0.5) - - if isinstance(margin, int): - - sell_limit = margin - - - else: - - sell_limit = pred_in_a_day.count() * margin - - buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit). - sell_limit should be no less than topk. - n_drop : int - number of stocks to be replaced in each trading date. - risk_degree: float - 0-1, 0.95 for example, use 95% money to trade. - str_type: 'amount', 'weight' or 'dropout' - strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy. - - - **exchange related arguments** - - exchange: Exchange() - pass the exchange for speeding up. - subscribe_fields: list - subscribe fields. - open_cost : float - open transaction cost. The default value is 0.002(0.2%). - close_cost : float - close transaction cost. The default value is 0.002(0.2%). - min_cost : float - min transaction cost. - trade_unit : int - 100 for China A. - deal_price: str - dealing price type: 'close', 'open', 'vwap'. - limit_threshold : float - limit move 0.1 (10%) for example, long and short with same limit. - extract_codes: bool - will we pass the codes extracted from the pred to the exchange. - - .. note:: This will be faster with offline qlib. - - - **executor related arguments** - - executor : BaseExecutor() - executor used in backtest. - verbose : bool - whether to print log. - - """ - # check strategy: - spec = inspect.getfullargspec(get_strategy) - str_args = {k: v for k, v in kwargs.items() if k in spec.args} - strategy = get_strategy(**str_args) - - # init exchange: spec = inspect.getfullargspec(get_exchange) ex_args = {k: v for k, v in kwargs.items() if k in spec.args} trade_exchange = get_exchange(pred, **ex_args) - # init executor: - executor = get_executor(executor=kwargs.get("executor"), trade_exchange=trade_exchange, verbose=verbose) + temp_env = trade_env + while True: + if hasattr(temp_env, "trade_exchange"): + temp_env.reset(trade_exchange=trade_exchange) + if hasattr(temp_env, "sub_env"): + temp_env = temp_env.sub_env + else: + break + + trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) + trade_strategy.reset(start_time=start_time, end_time=end_time) + trade_state = self.sub_env.get_first_state() + + + while not trade_env.finished(): + _order_list = self.sub_strategy.generate_order(**trade_state) + trade_state, trade_info = self.sub_env.execute(sub_order_list) + + report_df = trade_account.report.generate_report_dataframe() + positions = trade_account.get_positions() - # run backtest - report_dict = backtest_func( - pred=pred, - strategy=strategy, - executor=executor, - trade_exchange=trade_exchange, - shift=shift, - verbose=verbose, - account=account, - benchmark=benchmark, - return_order=return_order, - ) - # for compatibility of the old API. return the dict positions + report_dict = {"report_df": report_df, "positions": positions} - positions = report_dict.get("positions") - report_dict.update({"positions": {k: p.position for k, p in positions.items()}}) - return report_dict + return diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 038bbcf60..c44d26d7b 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -129,8 +129,7 @@ class Account: # judge whether the the trading is begin. # and don't add init account state into report, due to we don't have excess return in those days. self.report.update_report_record( - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + trade_time=trade_start_time, account_value=now_account_value, cash=self.current.position["cash"], return_rate=(self.earning + self.ct) / last_account_value, diff --git a/qlib/backtest/env.py b/qlib/backtest/env.py index 32ed91ef0..a4f1eb95e 100644 --- a/qlib/backtest/env.py +++ b/qlib/backtest/env.py @@ -3,6 +3,7 @@ import re import json import copy +import warnings import pathlib import pandas as pd from loguru import Logger @@ -22,70 +23,76 @@ class BaseEnv: def __init__( self, step_bar, - trade_account, start_time=None, end_time=None, - track=False, + trade_account=None, verbose=False, **kwargs ): self.step_bar = step_bar - self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, **kwargs) + self.verbose = verbose + self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs) - def _reset_trade_date(self, start_time=None, end_time=None): + def _reset_trade_calendar(self, start_time, end_time): if start_time: self.start_time = start_time if end_time: self.end_time = end_time - if not self.start_time or not self.end_time: - raise ValueError("value of `start_time` or `end_time` is None") - _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) - self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time)) - self.trade_len = len(self.trade_dates) - self.trade_index = 0 + if self.start_time and self.end_time: + _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) + self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time)) + self.trade_len = len(self.trade_calendar) + self.trade_index = 0 + else: + raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") - def reset(self, start_time=None, end_time=None, **kwargs): + def _get_position(self): + return self.trade_account.current + + def _get_trade_time(self): + if 0 < self.trade_index < self.trade_len - 1: + trade_start_time = self.trade_calendar[self.trade_index - 1] + trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1) + return trade_start_time, trade_end_time + elif self.trade_index == self.trade_len - 1: + trade_start_time = self.trade_calendar[self.trade_index - 1] + trade_end_time = self.trade_calendar[self.trade_index] + return trade_start_time, trade_end_time + else: + raise RuntimeError("trade_index out of range") + + def reset(self, start_time=None, end_time=None, trade_account=None, **kwargs): if start_time or end_time: - self._reset_trade_date(start_time=start_time, end_time=end_time) - self.track = kwargs.get("track", False) - self.upper_action = kwargs.get("upper_action", None) - self.trade_account = init_instance_by_config(kwargs.get("trade_account")) - return self.trade_account + self._reset_trade_calendar(start_time=start_time, end_time=end_time) + self.trade_account = trade_account - def execute(self, **kwargs): + def get_first_state(self): + init_state = {"current": self._get_position()} + return init_state + + + def execute(self, order_list, **kwargs): self.trade_index = self.trade_index + 1 - return - ( - self.trade_account, - { - "start_time": self.start_time, - "end_time": self.end_time, - "trade_len": self.trade_len, - "trade_index": self.trade_index - 1, - } - ) def finished(self): return self.trade_index >= self.trade_len - 1 - class SplitEnv(BaseEnv): def __init__( self, step_bar, - start_time, - end_time, - trade_account, sub_env, sub_strategy, - track=False, + start_time=None, + end_time=None, + trade_account=None, verbose=False, **kwargs ): self.sub_env = sub_env self.sub_strategy = sub_strategy - super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track) + super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, verbose=verbose) def execute(self, order_list, **kwargs): if self.finished(): @@ -93,16 +100,18 @@ class SplitEnv(BaseEnv): #if self.track: # yield action #episode_reward = 0 - trade_start_time = self.trade_dates[self.trade_index] - trade_end_time = self.trade_dates[self.trade_index + 1] - self.sub_strategy.reset(trade_order_list=order_list) - sub_account = self.sub_env.reset(trade_order_list=order_list, start_time=self.trade_dates[self.trade_index - 1], end_time=self.trade_dates[self.trade_index]) + super(SimulatorEnv, self).execute(**kwargs) + trade_start_time, trade_end_time = self._get_trade_time() + self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time, trade_account=self.trade_account) + self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) + trade_state = self.sub_env.get_first_state() while not self.sub_env.finished(): - sub_order_list = self.sub_strategy.generate_order(sub_account) - sub_account, sub_info = self.sub_env.execute(sub_order_list) + _order_list = self.sub_strategy.generate_order(**trade_state) + trade_state, trade_info = self.sub_env.execute(order_list=_order_list) #episode_reward += sub_reward - _account, _info = super(SimulatorEnv, self).execute(**kwargs) - return _account, _info + _obs = {"current": self._get_position()} + _info = {} + return _obs, _info @@ -111,16 +120,18 @@ class SimulatorEnv(BaseEnv): def __init__( self, step_bar, - start_time, - end_time, - trade_account, - trade_exchange, - track=False, + start_time=None, + end_time=None, + trade_account=None, + trade_exchange=None, verbose=False, - **kwargs + **kwargs, ): - self.trade_exchange = trade_exchange - super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose) + super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, trade_exchange=trade_exchange, verbose=verbose) + + def reset(trade_exchange=None, **kwargs): + super(SimulatorEnv, self).reset(**kwargs) + self.trade_exchange=trade_exchange def execute(self, order_list, **kwargs): """ @@ -128,9 +139,8 @@ class SimulatorEnv(BaseEnv): """ if self.finished(): raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") - - trade_start_time = self.trade_dates[self.trade_index] - trade_end_time = self.trade_dates[self.trade_index + 1] + super(SimulatorEnv, self).execute(**kwargs) + ttrade_start_time, trade_end_time = self._get_trade_time() trade_info = [] for order in order_list: if self.trade_exchange.check_order(order) is True: @@ -165,5 +175,6 @@ class SimulatorEnv(BaseEnv): # do nothing pass self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) - _account, _info = super(SimulatorEnv, self).execute(**kwargs) - return _account, {**_info, "trade_info", trade_info} \ No newline at end of file + _obs = {"current": self._get_position()} + _info = {"trade_info": trade_info} + return _obs, _info \ No newline at end of file diff --git a/qlib/contrib/strategy/dl_strategy.py b/qlib/contrib/strategy/dl_strategy.py index f3a227c85..737fd7a58 100644 --- a/qlib/contrib/strategy/dl_strategy.py +++ b/qlib/contrib/strategy/dl_strategy.py @@ -64,9 +64,9 @@ class TopkDropoutStrategy(DLStrategy): # self.stock_count['code'] will be the days the stock has been hold # since last buy signal. This is designed for thresh self.stock_count = {} - self.hold_thresh = hold_thresh self.only_tradable = only_tradable + def get_risk_degree(self, trade_index): """get_risk_degree @@ -76,12 +76,10 @@ class TopkDropoutStrategy(DLStrategy): # It will use 95% amoutn of your total value by default return self.risk_degree - def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs): + def generate_order_list(self, current, **kwargs): super(TopkDropoutStrategy, self).generate_order_list() - if self.trade_index == 1: - pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1) - else: - pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1) + trade_start_time, trade_end_time = self._get_trade_time() + pred_start_time, pred_end_time = self._get_last_trade_time() pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if self.only_tradable: # If The strategy only consider tradable stock when make decision @@ -114,7 +112,7 @@ class TopkDropoutStrategy(DLStrategy): def filter_stock(l): return l - current_temp = copy.deepcopy(trade_account.current) + current_temp = copy.deepcopy(current) # generate order list for this adjust date sell_order_list = [] buy_order_list = [] @@ -229,14 +227,15 @@ class TopkDropoutStrategy(DLStrategy): return sell_order_list + buy_order_list class WeightStrategyBase(DLStrategy): - def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, **kwargs): - super().__init__(**kwargs) + def __init__(self, trade_exchange, order_generator_cls_or_obj=OrderGenWInteract, start_time=None, end_time=None, **kwargs): + super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time) self.trade_exchange = trade_exchange if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): """ @@ -253,7 +252,7 @@ class WeightStrategyBase(DLStrategy): """ raise NotImplementedError() - def generate_order_list(self, trade_account, trade_start_time, trade_end_time, **kwargs): + def generate_order_list(self, current, **kwargs): """ Parameters ----------- @@ -269,11 +268,8 @@ class WeightStrategyBase(DLStrategy): # generate_order_list # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list super(WeightStrategyBase, self).generate_order_list() - if self.trade_index == 1: - pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1) - else: - pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1) - + trade_start_time, trade_end_time = self._get_trade_time() + pred_start_time, pred_end_time = self._get_pred_time() pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") current_temp = copy.deepcopy(trade_account.current) target_weight_position = self.generate_target_weight_position( diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index af43be246..31968dafa 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -9,27 +9,18 @@ from ...backtest.order import Order class TWAPStrategy(RuleStrategy, TradingEnhancement): - def __init__( - self, - step_bar, - start_time, - end_time, - **kwargs, - ): - self.step_bar = step_bar - self.reset(start_time=start_time, end_time=end_time, **kwargs) + + def reset(self, trade_order_list=None, **kwargs): + super(TWAPStrategy, self).reset(**kwargs) + TradingEnhancement.reset(trade_order_list=trade_order_list) self.trade_amount = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len - - def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs): - super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs) - TradingEnhancement.reset(trade_order_list=trade_order_list) + def generate_order_list(self, **kwargs): super(TopkDropoutStrategy, self).generate_order_list() - trade_start_time = self.trade_dates[self.trade_index - 1] - trade_end_time = self.trade_dates[self.trade_index] + trade_start_time, trade_end_time = self._get_trade_time() order_list = [] for order in self.trade_order_list: _order = Order( @@ -43,7 +34,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): order_list.append(_order) return order_list -class SBBStrategy(RuleStrategy, TradingEnhancement): +class SBBStrategyBase(RuleStrategy, TradingEnhancement): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. """ @@ -51,34 +42,22 @@ class SBBStrategy(RuleStrategy, TradingEnhancement): TREND_SHORT = 1 TREND_LONG = 2 - def __init__( - self, - step_bar, - start_time, - end_time, - **kwargs, - ): - self.step_bar = step_bar - self.reset(start_time=start_time, end_time=end_time, **kwargs) + def reset(self, trade_order_list=None, **kwargs): + TradingEnhancement.reset(trade_order_list=trade_order_list) self.trade_amount = {} self.trade_delay = {} for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount // self.trade_len self.trade_trend[(order.stock_id, order.direction)] = TREND_MID - - def reset(self, start_time=None, end_time=None, trade_order_list=None, **kwargs): - super(SignalStrategy, self).reset(start_time=start_time, end_time=end_time, **kwargs) - TradingEnhancement.reset(trade_order_list=trade_order_list) + super(SBBStrategyBase, self).reset(**kwargs) def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") - def generate_order_list(self, trade_start_time, trade_end_time, **kwargs): - super(TopkDropoutStrategy, self).generate_order_list() - if self.trade_index == 1: - pred_start_time, pred_end_time = None, trade_start_time - pd.Timedelta(seconds=1) - else: - pred_start_time, pred_end_time = self.trade_dates[self.trade_index - 2], trade_start_time - pd.Timedelta(seconds=1) + def generate_order_list(self, **kwargs): + super(SBBStrategyBase, self).generate_order_list() + trade_start_time, trade_end_time = self._get_trade_time() + pred_start_time, pred_end_time = self._get_last_trade_time() order_list = [] for order in self.trade_order_list: if self.trade_index % 2 == 1: @@ -124,7 +103,7 @@ class SBBStrategy(RuleStrategy, TradingEnhancement): return order_list -class SBBEMAStrategy(SBBStrategy): +class SBBStrategyEMA(SBBStrategyBase): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy with (EMA). """ @@ -137,17 +116,17 @@ class SBBEMAStrategy(SBBStrategy): freq="day", **kwargs, ): - self.step_bar = step_bar + super(SBBStrategyEMA, self).__init__(step_bar, start_time, end_time, **kwargs) if instruments is None: warnings.warn("`instruments` is not set, will load all stocks") self.instruments = "all" if isinstance(instruments, str): self.instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) self.freq = freq - self.reset(start_time=start_time, end_time=end_time) + - def _reset_trade_date(self, start_time=None, end_time=None): - super(SignalStrategy, self)._reset_trade_date(start_time=start_time, end_time=end_time) + def _reset_trade_calendar(self, start_time=None, end_time=None, _calendar=None): + super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time, _calendar=_calendar) fields = [("EMA...", "signal")] self.signal = D.features(instruments, fields, start_time=self.start_time, end_time=self.end_time, freq=self.freq) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 03b9d88c0..9f9be45cb 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -20,26 +20,49 @@ from ..data.data import D - label和freq和strategy的bar分离,这个如何决策呢 """ class BaseStrategy: - def __init__(self, step_bar, start_time, end_time, **kwargs): + def __init__(self, step_bar, start_time=None, end_time=None, **kwargs): self.step_bar = step_bar self.reset(start_time=start_time, end_time=end_time, **kwargs) - def _reset_trade_date(self, start_time=None, end_time=None): + def _reset_trade_calendar(self, start_time, end_time, _calendar=None): if start_time: self.start_time = start_time if end_time: self.end_time = end_time - if not self.start_time or not self.end_time: - raise ValueError("value of `start_time` or `end_time` is None") - _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) - self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time)) - self.trade_len = len(self.trade_dates) - self.trade_index = 0 - - def reset(self, start_time=None, end_time=None, **kwargs): - if start_time or end_time: - self._reset_trade_date(start_time=start_time, end_time=end_time) + if self.start_time and self.end_time: + if not _calendar: + _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) + self.trade_calendar = np.hstack(_calendar, pd.Timestamp(self.end_time)) + else: + self.trade_calendar = _calendar + self.trade_len = len(self.trade_calendar) + self.trade_index = 0 + else: + raise ValueError("failed to reset trade calendar, params `start_time` or `end_time` is None.") + def reset(self, start_time=None, end_time=None, _calendar=None): + if start_time or end_time : + self._reset_trade_calendar(start_time=start_time, end_time=end_time, calendar=calendar) + + def _get_trade_time(self): + if 0 < self.trade_index < self.trade_len - 1: + trade_start_time = self.trade_calendar[self.trade_index - 1] + trade_end_time = self.trade_calendar[self.trade_index] - pd.Timestamp(second=1) + return trade_start_time, trade_end_time + elif self.trade_index == self.trade_len - 1: + trade_start_time = self.trade_calendar[self.trade_index - 1] + trade_end_time = self.trade_calendar[self.trade_index] + return trade_start_time, trade_end_time + else: + raise RuntimeError("trade_index out of range") + + def _get_last_trade_time(self, shift=1): + if self.trade_index - shift < 0: + return None, None + elif self.trade_index - shift == 0: + return None, self.trade_index[self.trade_index - shift] + else: + return self.trade_index[self.trade_index - shift - 1], self.trade_index[self.trade_index - shift] def generate_order_list(self, **kwargs): self.trade_index = self.trade_index + 1 @@ -48,7 +71,7 @@ class RuleStrategy(BaseStrategy): pass class DLStrategy(BaseStrategy): - def __init__(self, step_bar, start_time, end_time, model, dataset:DatasetH): + def __init__(self, step_bar, model, dataset:DatasetH, start_time=None, end_time=None): self.model = model self.dataset = dataset self.pred_scores = self.model.predict(dataset) @@ -62,6 +85,5 @@ class DLStrategy(BaseStrategy): class TradingEnhancement: def reset(self, trade_order_list): - if trade_order_list: - self.trade_order_list = trade_order_list + self.trade_order_list = trade_order_list