From a401f1eafe68398b77b0142445663bbcdf0e080f Mon Sep 17 00:00:00 2001 From: Young Date: Wed, 30 Jun 2021 08:50:03 +0000 Subject: [PATCH 1/6] improve the docstring --- qlib/contrib/strategy/rule_strategy.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 22483a79c..e6779f124 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -707,12 +707,33 @@ class RandomOrderStrategy(BaseStrategy): class FileOrderStrategy(BaseStrategy): """ - Motivtaion: + Motivation: - This class provides an interface for user to read orders from csv files. - - It is supposed to be used in """ def __init__(self, file: Union[IO, str, Path], index_range: Tuple[int, int] = None, *args, **kwargs): + """ + + Parameters + ---------- + file : Union[IO, str, Path] + this parameters will specify the info of expected orders + Here is an example of the content + + datetime,instrument,amount,direction + 20200102, SH600519, 1000, sell + 20200103, SH600519, 1000, buy + 20200106, SH600519, 1000, sell + + index_range : Tuple[int, int] + the intra day time index range of the orders + the left and right is closed. + + If you want to get the index_range in intra-day + - `qlib/utils/time.py:def get_day_min_idx_range` can help you create the index range easier + # TODO: this is a index_range level limitation. We'll implement a more detailed limitation later. + + """ super().__init__(*args, **kwargs) with get_io_object(file) as f: self.order_df = pd.read_csv(f, dtype={"datetime": np.str}) From 8b85b9eee79b930c0cb3de44456935e5562a281b Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 1 Jul 2021 14:35:49 +0000 Subject: [PATCH 2/6] optimize performance of resam data in rule_strategy & exchange --- qlib/backtest/exchange.py | 25 +++++++-------- qlib/contrib/strategy/rule_strategy.py | 21 +++++++------ qlib/utils/resam.py | 42 ++++++++++++++++++++++++++ 3 files changed, 65 insertions(+), 23 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index a759dbd86..f5a366510 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -12,7 +12,7 @@ import pandas as pd from ..data.data import D from ..data.dataset.utils import get_level_index from ..config import C, REG_CN -from ..utils.resam import resam_ts_data +from ..utils.resam import resam_ts_data, ts_data_last from ..log import get_module_logger from .order import Order @@ -166,7 +166,7 @@ class Exchange: quote_dict = {} for stock_id, stock_val in quote_df.groupby(level="instrument"): - quote_dict[stock_id] = stock_val + quote_dict[stock_id] = stock_val.droplevel(level="instrument") self.quote = quote_dict @@ -186,13 +186,13 @@ class Exchange: """ if direction is None: - buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] - sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") + sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") return buy_limit or sell_limit elif direction == Order.BUY: - return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all") elif direction == Order.SELL: - return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all") else: raise ValueError(f"direction {direction} is not supported!") @@ -267,16 +267,16 @@ class Exchange: ) def get_quote_info(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0] + return resam_ts_data(self.quote[stock_id], start_time, end_time, method=ts_data_last) def get_close(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method="last").iloc[0] + return resam_ts_data(self.quote[stock_id]["$close"], start_time, end_time, method=ts_data_last) def get_volume(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum").iloc[0] + return resam_ts_data(self.quote[stock_id]["$volume"], start_time, end_time, method="sum") def get_deal_price(self, stock_id, start_time, end_time): - deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method="last").iloc[0] + deal_price = resam_ts_data(self.quote[stock_id][self.deal_price], start_time, end_time, method=ts_data_last) if np.isclose(deal_price, 0.0) or np.isnan(deal_price): self.logger.warning( f"(stock_id:{stock_id}, trade_time:{(start_time, end_time)}, {self.deal_price}): {deal_price}!!!" @@ -295,10 +295,7 @@ class Exchange: """ if stock_id not in self.quote: return None - res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last") - if res is not None: - res = res.iloc[0] - return res + return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method=ts_data_last) def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 20099d4d3..5d26f0e30 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -3,7 +3,7 @@ import numpy as np import pandas as pd from typing import List, Tuple, Union -from ...utils.resam import resam_ts_data +from ...utils.resam import resam_ts_data, ts_data_last from ...data.data import D from ...strategy.base import BaseStrategy from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO @@ -427,7 +427,7 @@ class SBBStrategyEMA(SBBStrategyBase): if not signal_df.empty: for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + self.signal[stock_id] = stock_val["signal"].droplevel(level="instrument") def reset_level_infra(self, level_infra): """ @@ -449,13 +449,16 @@ class SBBStrategyEMA(SBBStrategyBase): return self.TREND_MID else: _sample_signal = resam_ts_data( - self.signal[stock_id]["signal"], pred_start_time, pred_end_time, method="last" + self.signal[stock_id], + pred_start_time, + pred_end_time, + method=ts_data_last, ) # if EMA signal == 0 or None, return mid trend - if _sample_signal is None or _sample_signal.iloc[0] == 0: + if _sample_signal is None or np.isnan(_sample_signal) or _sample_signal == 0: return self.TREND_MID # if EMA signal > 0, return long trend - elif _sample_signal.iloc[0] > 0: + elif _sample_signal > 0: return self.TREND_LONG # if EMA signal < 0, return short trend else: @@ -518,7 +521,7 @@ class ACStrategy(BaseStrategy): if not signal_df.empty: for stock_id, stock_val in signal_df.groupby(level="instrument"): - self.signal[stock_id] = stock_val + self.signal[stock_id] = stock_val["volatility"].droplevel(level="instrument") def reset_common_infra(self, common_infra): """ @@ -585,12 +588,12 @@ class ACStrategy(BaseStrategy): # considering trade unit sig_sam = ( - resam_ts_data(self.signal[order.stock_id]["volatility"], pred_start_time, pred_end_time, method="last") + resam_ts_data(self.signal[order.stock_id], pred_start_time, pred_end_time, method=ts_data_last) if order.stock_id in self.signal else None ) - if sig_sam is None or sig_sam.iloc[0] is None: + if sig_sam is None or np.isnan(sig_sam): # no signal, TWAP _amount_trade_unit = self.trade_exchange.get_amount_of_trade_unit(order.factor) if _amount_trade_unit is None: @@ -607,7 +610,7 @@ class ACStrategy(BaseStrategy): ) else: # VA strategy - kappa_tild = self.lamb / self.eta * sig_sam.iloc[0] * sig_sam.iloc[0] + kappa_tild = self.lamb / self.eta * sig_sam * sig_sam kappa = np.arccosh(kappa_tild / 2 + 1) amount_ratio = ( np.sinh(kappa * (trade_len - trade_step)) - np.sinh(kappa * (trade_len - trade_step - 1)) diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 4df155946..7782b8486 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -263,3 +263,45 @@ def resam_ts_data( elif isinstance(method, str): return getattr(feature, method)(**method_kwargs) return feature + + +def get_valid_value(series, last=True): + """get the first/last not nan value of pd.Series with single level index + Parameters + ---------- + series : pd.Seires + last : bool, optional + wether to get the last valid value, by default True + - if last is True, get the last valid value + - else, get the first valid value + + Returns + ------- + Nan | float + the first/last valid value + """ + x = series.dropna() + if x.empty: + return np.nan + else: + return x.iloc[-1] if last else x.iloc[0] + + +def ts_data_last(ts_feature): + """get the last not nan value of pd.Series|DataFrame with single level index""" + if isinstance(ts_feature, pd.DataFrame): + return ts_feature.apply(lambda column: get_valid_value(column, last=True)) + elif isinstance(ts_feature, pd.Series): + return get_valid_value(ts_feature, last=True) + else: + raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") + + +def ts_data_first(ts_feature): + """get the first not nan value of pd.Series|DataFrame with single level index""" + if isinstance(ts_feature, pd.DataFrame): + return ts_feature.apply(lambda column: get_valid_value(column, last=False)) + elif isinstance(ts_feature, pd.Series): + return get_valid_value(ts_feature, last=False) + else: + raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") From 8dd5788bacb313ad023e80eef5fc263186e045f3 Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 1 Jul 2021 16:31:58 +0000 Subject: [PATCH 3/6] fix comments & update resam ts_last method --- .../nested_decision_execution/workflow.py | 16 ++++----- qlib/backtest/report.py | 2 +- qlib/data/dataset/handler.py | 34 +++++++++++++------ qlib/data/dataset/storage.py | 11 +++++- qlib/utils/resam.py | 7 ++-- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index b6c1362fd..3108960c8 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -124,14 +124,14 @@ class NestedDecisionExecutionWorkflow: def _init_qlib(self): """initialize qlib""" - # provider_uri_day = "/data/stock_data/huaxia/qlib" - # provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" - provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) - provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") - GetData().qlib_data( - target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True - ) + provider_uri_day = "/data/stock_data/huaxia/qlib" + provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" + # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) + # provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + # GetData().qlib_data( + # target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True + # ) provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} client_config = { "calendar_provider": { diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index f217ea169..7623af551 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -91,7 +91,7 @@ class Report: if freq is None: raise ValueError("benchmark freq can't be None!") - _codes = benchmark if isinstance(benchmark, list) else [benchmark] + _codes = benchmark if isinstance(benchmark, (list, dict)) else [benchmark] fields = ["$close/Ref($close,1)-1"] _temp_result, _ = get_higher_eq_freq_feature(_codes, fields, start_time, end_time, freq=freq) if len(_temp_result) == 0: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index edcc1ede2..2d5159292 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -197,7 +197,7 @@ class DataHandler(Serializable): ------- pd.DataFrame. """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._data if isinstance(data_storage, pd.DataFrame): @@ -211,10 +211,17 @@ class DataHandler(Serializable): # Fetch column first will be more friendly to SepDataFrame data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") @@ -522,7 +529,7 @@ class DataHandlerLP(DataHandler): ------- pd.DataFrame: """ - from .storage import HasingStockStorage + from .storage import BaseHandlerStorage data_storage = self._get_df_by_key(data_key) if isinstance(data_storage, pd.DataFrame): @@ -537,10 +544,17 @@ class DataHandlerLP(DataHandler): data_df = fetch_df_by_col(data_df, col_set) data_df = fetch_df_by_index(data_df, selector, level, fetch_orig=self.fetch_orig) - elif isinstance(data_storage, HasingStockStorage): - if proc_func is not None: - raise ValueError("proc_func is not supported by the HasingStockStorage") - data_df = data_storage.fetch(selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig) + elif isinstance(data_storage, BaseHandlerStorage): + if not data_storage.is_proc_func_supported(): + if proc_func is not None: + raise ValueError(f"proc_func is not supported by the storage {type(data_storage)}") + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig + ) + else: + data_df = data_storage.fetch( + selector=selector, level=level, col_set=col_set, fetch_orig=self.fetch_orig, proc_func=proc_func + ) else: raise TypeError(f"data_storage should be pd.DataFrame|HasingStockStorage, not {type(data_storage)}") diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index 247970481..cd38bbefa 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -14,6 +14,7 @@ class BaseHandlerStorage: level: Union[str, int] = "datetime", col_set: Union[str, List[str]] = DataHandler.CS_ALL, fetch_orig: bool = True, + proc_func: Callable = None, **kwargs, ) -> pd.DataFrame: """fetch data from the data storage @@ -24,6 +25,7 @@ class BaseHandlerStorage: describe how to select data by index level : Union[str, int] which index level to select the data + - if level is None, apply selector to df directly col_set : Union[str, List[str]] - if isinstance(col_set, str): select a set of meaningful columns.(e.g. features, columns) @@ -33,7 +35,8 @@ class BaseHandlerStorage: select several sets of meaningful columns, the returned data has multiple level fetch_orig : bool Return the original data instead of copy if possible. - + proc_func: Callable + please refer to the doc of DataHandler.fetch """ raise NotImplementedError("fetch is method not implemented!") @@ -42,6 +45,9 @@ class BaseHandlerStorage: def from_df(df: pd.DataFrame): raise NotImplementedError("from_df method is not implemented!") + def is_proc_func_supported(self): + raise NotImplementedError("is_proc_func_supported method is not implemented!") + class HasingStockStorage(BaseHandlerStorage): def __init__(self, df): @@ -105,3 +111,6 @@ class HasingStockStorage(BaseHandlerStorage): return fetch_stock_df_list[0] else: return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) + + def is_proc_func_supported(self): + return False diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 7782b8486..7e0dc141c 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -270,6 +270,7 @@ def get_valid_value(series, last=True): Parameters ---------- series : pd.Seires + series should not be empty last : bool, optional wether to get the last valid value, by default True - if last is True, get the last valid value @@ -280,11 +281,7 @@ def get_valid_value(series, last=True): Nan | float the first/last valid value """ - x = series.dropna() - if x.empty: - return np.nan - else: - return x.iloc[-1] if last else x.iloc[0] + return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0] def ts_data_last(ts_feature): From ef7fe8aa75c7e0fb48518721b5109ff10a651f55 Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 3 Jul 2021 08:46:09 +0000 Subject: [PATCH 4/6] support parallel HF trading --- qlib/backtest/exchange.py | 13 ++----- qlib/backtest/executor.py | 52 ++++++++++++++++++++++++-- qlib/contrib/strategy/rule_strategy.py | 3 ++ 3 files changed, 55 insertions(+), 13 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 8177d53ee..34c0ef744 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -242,6 +242,7 @@ class Exchange: raise ValueError("trade_account and position can only choose one") trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time) + # NOTE: order will be changed in this function trade_val, trade_cost = self._calc_trade_info_by_order( order, trade_account.current if trade_account else position ) @@ -256,16 +257,6 @@ class Exchange: return trade_val, trade_cost, trade_price - def create_order(self, code, amount, start_time, end_time, direction) -> Order: - return Order( - stock_id=code, - amount=amount, - start_time=start_time, - end_time=end_time, - direction=direction, - factor=self.get_factor(code, start_time, end_time), - ) - def get_quote_info(self, stock_id, start_time, end_time): return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0] @@ -471,6 +462,8 @@ class Exchange: """ Calculation of trade info + **NOTE**: Order will be changed in this function + :param order: :param position: Position :return: trade_val, trade_cost diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index 3f7b2f4ed..ea2a0567d 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -1,7 +1,7 @@ import copy import warnings import pandas as pd -from typing import Union +from typing import List, Union from qlib.backtest.report import Indicator @@ -317,6 +317,15 @@ class NestedExecutor(BaseExecutor): class SimulatorExecutor(BaseExecutor): """Executor that simulate the true market""" + # available trade_types + TT_SERIAL = "serial" + ## The orders will be executed serially in a sequence + # In each trading step, it is possible that users sell instruments first and use the money to buy new instruments + TT_PARAL = "parallel" + ## The orders will be executed parallelly + # In each trading step, if users try to sell instruments first and buy new instruments with money, failure will + # occur + def __init__( self, time_per_step: str, @@ -328,6 +337,7 @@ class SimulatorExecutor(BaseExecutor): track_data: bool = False, trade_exchange: Exchange = None, common_infra: CommonInfrastructure = None, + trade_type: str = TT_PARAL, **kwargs, ): """ @@ -336,6 +346,8 @@ class SimulatorExecutor(BaseExecutor): trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra + trade_type: str + please refer to the doc of `TT_SERIAL` & `TT_PARAL` """ super(SimulatorExecutor, self).__init__( time_per_step=time_per_step, @@ -351,6 +363,8 @@ class SimulatorExecutor(BaseExecutor): if trade_exchange is not None: self.trade_exchange = trade_exchange + self.trade_type = trade_type + def reset_common_infra(self, common_infra): """ reset infrastructure for trading @@ -360,14 +374,45 @@ class SimulatorExecutor(BaseExecutor): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") + def _get_order_iterator(self, trade_decision: BaseTradeDecision) -> List[Order]: + """ + + Parameters + ---------- + trade_decision : BaseTradeDecision + the trade decision given by the strategy + + Returns + ------- + List[Order]: + get a list orders according to `self.trade_type` + """ + orders = trade_decision.get_decision() + + if self.trade_type == self.TT_SERIAL: + # Orders will be traded in a parallel way + order_it = orders + elif self.trade_type == self.TT_PARAL: + # NOTE: !!!!!!! + # Assumption: there will not be orders in different trading direction in a single step of a strategy !!!! + # The parallel trading failure will be caused only by the confliction of money + # Therefore, make the buying go first will make sure the confliction happen. + # It equals to parallel trading after sorting the order by direction + order_it = sorted(orders, key=lambda order: -order.direction) + else: + raise NotImplementedError(f"This type of input is not supported") + return order_it + def execute(self, trade_decision: BaseTradeDecision): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) execute_result = [] - for order in trade_decision.get_decision(): + + for order in self._get_order_iterator(trade_decision): if self.trade_exchange.check_order(order) is True: - # execute the order + # execute the order. + # NOTE: The trade_account will be changed in this function trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( order, trade_account=self.trade_account ) @@ -404,6 +449,7 @@ class SimulatorExecutor(BaseExecutor): # do nothing pass + # Account will not be changed in this function self.trade_account.update_bar_end( trade_start_time, trade_end_time, diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index e6779f124..2bc01045d 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -718,8 +718,11 @@ class FileOrderStrategy(BaseStrategy): ---------- file : Union[IO, str, Path] this parameters will specify the info of expected orders + Here is an example of the content + 1) Amount (**adjusted**) based strategy + datetime,instrument,amount,direction 20200102, SH600519, 1000, sell 20200103, SH600519, 1000, buy From ecf2f24d598f19022b5dbf3e2a3819801bd5c7a9 Mon Sep 17 00:00:00 2001 From: bxdd Date: Sat, 3 Jul 2021 18:42:40 +0000 Subject: [PATCH 5/6] fix comments --- .../nested_decision_execution/workflow.py | 16 +++++++-------- qlib/data/dataset/storage.py | 8 +++++++- qlib/utils/resam.py | 20 ++++++++----------- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index 3108960c8..b6c1362fd 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -124,14 +124,14 @@ class NestedDecisionExecutionWorkflow: def _init_qlib(self): """initialize qlib""" - provider_uri_day = "/data/stock_data/huaxia/qlib" - provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" - # provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir - # GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) - # provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") - # GetData().qlib_data( - # target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True - # ) + # provider_uri_day = "/data/stock_data/huaxia/qlib" + # provider_uri_1min = "/data2/stock_data/huaxia_1min_qlib" + provider_uri_day = "~/.qlib/qlib_data/cn_data" # target_dir + GetData().qlib_data(target_dir=provider_uri_day, region=REG_CN, version="v2", exists_skip=True) + provider_uri_1min = HIGH_FREQ_CONFIG.get("provider_uri") + GetData().qlib_data( + target_dir=provider_uri_1min, interval="1min", region=REG_CN, version="v2", exists_skip=True + ) provider_uri_map = {"1min": provider_uri_1min, "day": provider_uri_day} client_config = { "calendar_provider": { diff --git a/qlib/data/dataset/storage.py b/qlib/data/dataset/storage.py index cd38bbefa..9325807f9 100644 --- a/qlib/data/dataset/storage.py +++ b/qlib/data/dataset/storage.py @@ -37,8 +37,12 @@ class BaseHandlerStorage: Return the original data instead of copy if possible. proc_func: Callable please refer to the doc of DataHandler.fetch - """ + Returns + ------- + pd.DataFrame + the dataframe fetched + """ raise NotImplementedError("fetch is method not implemented!") @staticmethod @@ -46,6 +50,7 @@ class BaseHandlerStorage: raise NotImplementedError("from_df method is not implemented!") def is_proc_func_supported(self): + """whether the arg `proc_func` in `fetch` method is supported.""" raise NotImplementedError("is_proc_func_supported method is not implemented!") @@ -113,4 +118,5 @@ class HasingStockStorage(BaseHandlerStorage): return pd.concat(fetch_stock_df_list, sort=False, copy=~fetch_orig) def is_proc_func_supported(self): + """the arg `proc_func` in `fetch` method is not supported in HasingStockStorage""" return False diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index 7e0dc141c..9e9590e30 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -3,6 +3,8 @@ import datetime import numpy as np import pandas as pd + +from functools import partial from typing import Tuple, List, Union, Optional, Callable from . import lazy_sort_index @@ -284,21 +286,15 @@ def get_valid_value(series, last=True): return series.fillna(method="ffill").iloc[-1] if last else series.fillna(method="bfill").iloc[0] -def ts_data_last(ts_feature): - """get the last not nan value of pd.Series|DataFrame with single level index""" +def _ts_data_valid(ts_feature, last=False): + """get the first/last not nan value of pd.Series|DataFrame with single level index""" if isinstance(ts_feature, pd.DataFrame): - return ts_feature.apply(lambda column: get_valid_value(column, last=True)) + return ts_feature.apply(lambda column: get_valid_value(column, last=last)) elif isinstance(ts_feature, pd.Series): - return get_valid_value(ts_feature, last=True) + return get_valid_value(ts_feature, last=last) else: raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") -def ts_data_first(ts_feature): - """get the first not nan value of pd.Series|DataFrame with single level index""" - if isinstance(ts_feature, pd.DataFrame): - return ts_feature.apply(lambda column: get_valid_value(column, last=False)) - elif isinstance(ts_feature, pd.Series): - return get_valid_value(ts_feature, last=False) - else: - raise TypeError(f"ts_feature should be pd.DataFrame/Series, not {type(ts_feature)}") +ts_data_last = partial(_ts_data_valid, last=False) +ts_data_first = partial(_ts_data_valid, last=True) From 50c0e99f9895c6176266d52ff83849eb50e7b32e Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 4 Jul 2021 06:41:34 +0000 Subject: [PATCH 6/6] fix ffr and order amount --- qlib/backtest/account.py | 7 +++++-- qlib/backtest/executor.py | 2 ++ qlib/backtest/order.py | 31 +++++++++++++++++++++++++++++-- qlib/backtest/report.py | 29 +++++++++++++++++++++-------- 4 files changed, 57 insertions(+), 12 deletions(-) diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index 6167ee407..0d89dde87 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -9,7 +9,7 @@ import pandas as pd from .position import BasePosition, InfPosition, Position from .report import Report, Indicator -from .order import Order +from .order import BaseTradeDecision, Order from .exchange import Exchange """ @@ -226,6 +226,7 @@ class Account: trade_end_time: pd.Timestamp, trade_exchange: Exchange, atomic: bool, + outer_trade_decision: BaseTradeDecision, generate_report: bool = False, trade_info: list = None, inner_order_indicators: Indicator = None, @@ -276,7 +277,9 @@ class Account: if atomic: self.indicator.update_order_indicators(trade_start_time, trade_end_time, trade_info, trade_exchange) else: - self.indicator.agg_order_indicators(inner_order_indicators, indicator_config) + self.indicator.agg_order_indicators( + inner_order_indicators, indicator_config=indicator_config, outer_trade_decision=outer_trade_decision + ) self.indicator.cal_trade_indicators(trade_start_time, self.freq, indicator_config) self.indicator.record(trade_start_time) diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index ea2a0567d..14d97e825 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -299,6 +299,7 @@ class NestedExecutor(BaseExecutor): trade_end_time, self.trade_exchange, atomic=False, + outer_trade_decision=trade_decision, generate_report=self.generate_report, inner_order_indicators=inner_order_indicators, indicator_config=self.indicator_config, @@ -455,6 +456,7 @@ class SimulatorExecutor(BaseExecutor): trade_end_time, self.trade_exchange, atomic=True, + outer_trade_decision=trade_decision, generate_report=self.generate_report, trade_info=execute_result, indicator_config=self.indicator_config, diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index 9df162263..1767deb62 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -39,7 +39,7 @@ class Order: """ stock_id: str - amount: float + amount: float # `amount` is a non-negative value # The interval of the order which belongs to (NOTE: this is not the expected order dealing range time) start_time: pd.Timestamp @@ -47,7 +47,7 @@ class Order: direction: int factor: float - deal_amount: float = field(init=False) + deal_amount: float = field(init=False) # `deal_amount` is a non-negative value # FIXME: # for compatible now. @@ -60,6 +60,33 @@ class Order: raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") self.deal_amount = 0 + @property + def amount_delta(self) -> float: + """ + return the delta of amount. + - Positive value indicates buying `amount` of share + - Negative value indicates selling `amount` of share + """ + return self.amount * self.sign + + @property + def deal_amount_delta(self) -> float: + """ + return the delta of deal_amount. + - Positive value indicates buying `deal_amount` of share + - Negative value indicates selling `deal_amount` of share + """ + return self.deal_amount * self.sign + + @property + def sign(self) -> float: + """ + return the sign of trading + - `+1` indicates buying + - `-1` value indicates selling + """ + return self.direction * 2 - 1 + @staticmethod def parse_dir(direction: Union[str, int, OrderDir]) -> OrderDir: if isinstance(direction, OrderDir): diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 7623af551..ce2812bd0 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -4,6 +4,8 @@ from collections import OrderedDict from logging import warning +from typing import List +from qlib.backtest.order import BaseTradeDecision, Order import pandas as pd import pathlib import warnings @@ -241,13 +243,13 @@ class Indicator: trade_cost = dict() for order, _trade_val, _trade_cost, _trade_price in trade_info: - amount[order.stock_id] = order.amount * (order.direction * 2 - 1) - deal_amount[order.stock_id] = order.deal_amount * (order.direction * 2 - 1) + amount[order.stock_id] = order.amount_delta + deal_amount[order.stock_id] = order.deal_amount_delta trade_price[order.stock_id] = _trade_price - trade_value[order.stock_id] = _trade_val * (order.direction * 2 - 1) + trade_value[order.stock_id] = _trade_val * order.sign trade_cost[order.stock_id] = _trade_cost - self.order_indicator["amount"] = pd.Series(amount) + self.order_indicator["amount"] = self.order_indicator["inner_amount"] = pd.Series(amount) self.order_indicator["deal_amount"] = pd.Series(deal_amount) self.order_indicator["trade_price"] = pd.Series(trade_price) self.order_indicator["trade_value"] = pd.Series(trade_value) @@ -271,13 +273,13 @@ class Indicator: ) / self.order_indicator["base_price"] def _agg_order_trade_info(self, inner_order_indicators): - amount = pd.Series() + inner_amount = pd.Series() deal_amount = pd.Series() trade_price = pd.Series() trade_value = pd.Series() trade_cost = pd.Series() for _order_indicator in inner_order_indicators: - amount = amount.add(_order_indicator["amount"], fill_value=0) + inner_amount = inner_amount.add(_order_indicator["inner_amount"], fill_value=0) deal_amount = deal_amount.add(_order_indicator["deal_amount"], fill_value=0) trade_price = trade_price.add( _order_indicator["trade_price"] * _order_indicator["deal_amount"], fill_value=0 @@ -285,13 +287,21 @@ class Indicator: trade_value = trade_value.add(_order_indicator["trade_value"], fill_value=0) trade_cost = trade_cost.add(_order_indicator["trade_cost"], fill_value=0) - self.order_indicator["amount"] = amount + self.order_indicator["inner_amount"] = inner_amount self.order_indicator["deal_amount"] = deal_amount trade_price /= self.order_indicator["deal_amount"] self.order_indicator["trade_price"] = trade_price self.order_indicator["trade_value"] = trade_value self.order_indicator["trade_cost"] = trade_cost + def _update_trade_amount(self, outer_trade_decision: BaseTradeDecision): + # NOTE: these indicator is designed for order execution, so the + decision: List[Order] = outer_trade_decision.get_decision() + if decision is None: + self.order_indicator["amount"] = pd.Series() + else: + self.order_indicator["amount"] = pd.Series({order.stock_id: order.amount_delta for order in decision}) + def _agg_order_fulfill_rate(self): self.order_indicator["ffr"] = self.order_indicator["deal_amount"] / self.order_indicator["amount"] @@ -367,8 +377,11 @@ class Indicator: self._update_order_fulfill_rate() self._update_order_price_advantage(trade_exchange, trade_start_time, trade_end_time) - def agg_order_indicators(self, inner_order_indicators, indicator_config={}): + def agg_order_indicators( + self, inner_order_indicators, outer_trade_decision: BaseTradeDecision, indicator_config={} + ): self._agg_order_trade_info(inner_order_indicators) + self._update_trade_amount(outer_trade_decision) self._agg_order_fulfill_rate() pa_config = indicator_config.get("pa_config", {}) self._agg_order_price_advantage(inner_order_indicators, base_price=pa_config.get("base_price", "twap"))