From 4085b447aab98e23f3c6106038705fd1a423471d Mon Sep 17 00:00:00 2001 From: bxdd Date: Thu, 27 May 2021 21:14:39 +0800 Subject: [PATCH] move backtest to core, fix calendar bugs, add some docstring --- examples/multi_level_trading/workflow.py | 10 +- examples/workflow_by_code.ipynb | 1 - examples/workflow_by_code.py | 11 -- qlib/{contrib => }/backtest/__init__.py | 8 +- qlib/{contrib => }/backtest/account.py | 2 + qlib/{contrib => }/backtest/backtest.py | 0 qlib/{contrib => }/backtest/exchange.py | 12 +- qlib/{contrib => }/backtest/executor.py | 54 ++++---- qlib/{contrib => }/backtest/order.py | 0 qlib/{contrib => }/backtest/position.py | 0 .../backtest/profit_attribution.py | 4 +- qlib/{contrib => }/backtest/report.py | 8 +- qlib/backtest/utils.py | 98 ++++++++++++++ qlib/config.py | 6 +- qlib/contrib/backtest/utils.py | 67 ---------- qlib/contrib/evaluate.py | 6 +- .../analysis_position/parse_position.py | 2 +- .../report/analysis_position/rank_label.py | 2 +- qlib/contrib/strategy/cost_control.py | 5 +- qlib/contrib/strategy/model_strategy.py | 34 +++-- qlib/contrib/strategy/order_generator.py | 4 +- qlib/contrib/strategy/rule_strategy.py | 126 ++++++++++++------ qlib/data/data.py | 4 +- qlib/rl/env.py | 3 +- qlib/strategy/base.py | 16 +-- qlib/utils/resam.py | 29 ++-- qlib/workflow/record_temp.py | 2 +- 27 files changed, 298 insertions(+), 216 deletions(-) rename qlib/{contrib => }/backtest/__init__.py (96%) rename qlib/{contrib => }/backtest/account.py (99%) rename qlib/{contrib => }/backtest/backtest.py (100%) rename qlib/{contrib => }/backtest/exchange.py (98%) rename qlib/{contrib => }/backtest/executor.py (89%) rename qlib/{contrib => }/backtest/order.py (100%) rename qlib/{contrib => }/backtest/position.py (100%) rename qlib/{contrib => }/backtest/profit_attribution.py (99%) rename qlib/{contrib => }/backtest/report.py (97%) create mode 100644 qlib/backtest/utils.py delete mode 100644 qlib/contrib/backtest/utils.py diff --git a/examples/multi_level_trading/workflow.py b/examples/multi_level_trading/workflow.py index ea11d4e7f..8096fc76f 100644 --- a/examples/multi_level_trading/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -10,7 +10,7 @@ from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData -from qlib.contrib.backtest import collect_data +from qlib.backtest import collect_data class MultiLevelTradingWorkflow: @@ -61,17 +61,17 @@ class MultiLevelTradingWorkflow: } trade_start_time = "2017-01-01" - trade_end_time = "2017-02-01" + trade_end_time = "2020-08-01" port_analysis_config = { "executor": { - "class": "SplitExecutor", - "module_path": "qlib.contrib.backtest.executor", + "class": "NestedExecutor", + "module_path": "qlib.backtest.executor", "kwargs": { "time_per_step": "week", "inner_executor": { "class": "SimulatorExecutor", - "module_path": "qlib.contrib.backtest.executor", + "module_path": "qlib.backtest.executor", "kwargs": { "time_per_step": "day", "verbose": True, diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb index 1dda1c621..b4da1bfe4 100644 --- a/examples/workflow_by_code.ipynb +++ b/examples/workflow_by_code.ipynb @@ -66,7 +66,6 @@ "from qlib.config import REG_CN\n", "from qlib.contrib.model.gbdt import LGBModel\n", "from qlib.contrib.data.handler import Alpha158\n", - "from qlib.contrib.strategy.strategy import TopkDropoutStrategy\n", "from qlib.contrib.evaluate import (\n", " backtest as normal_backtest,\n", " risk_analysis,\n", diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index d5dab8917..92ce6aa34 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -1,19 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys -from pathlib import Path - import qlib -import pandas as pd from qlib.config import REG_CN -from qlib.contrib.model.gbdt import LGBModel -from qlib.contrib.data.handler import Alpha158 -from qlib.contrib.strategy.strategy import TopkDropoutStrategy -from qlib.contrib.evaluate import ( - backtest as normal_backtest, - risk_analysis, -) from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord diff --git a/qlib/contrib/backtest/__init__.py b/qlib/backtest/__init__.py similarity index 96% rename from qlib/contrib/backtest/__init__.py rename to qlib/backtest/__init__.py index effab026b..12db0a314 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/backtest/__init__.py @@ -7,10 +7,10 @@ from .executor import BaseExecutor from .backtest import backtest as backtest_func from .backtest import collect_data as data_generator -from ...strategy.base import BaseStrategy -from ...utils import init_instance_by_config -from ...log import get_module_logger -from ...config import C +from ..strategy.base import BaseStrategy +from ..utils import init_instance_by_config +from ..log import get_module_logger +from ..config import C logger = get_module_logger("backtest caller") diff --git a/qlib/contrib/backtest/account.py b/qlib/backtest/account.py similarity index 99% rename from qlib/contrib/backtest/account.py rename to qlib/backtest/account.py index c7571bc98..dfe248c68 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/backtest/account.py @@ -24,6 +24,8 @@ rtn & earning in the Account **is consider cost** while earning is the difference of two position value, so it considers cost, it is the true return rate in the specific accomplishment for rtn, it does not consider cost, in other words, rtn - cost = earning + +Now rtn has been removed in the hierarchical backtest implemention. """ diff --git a/qlib/contrib/backtest/backtest.py b/qlib/backtest/backtest.py similarity index 100% rename from qlib/contrib/backtest/backtest.py rename to qlib/backtest/backtest.py diff --git a/qlib/contrib/backtest/exchange.py b/qlib/backtest/exchange.py similarity index 98% rename from qlib/contrib/backtest/exchange.py rename to qlib/backtest/exchange.py index 09b7f2a63..de2df98be 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -8,11 +8,11 @@ import logging import numpy as np import pandas as pd -from ...data.data import D -from ...data.dataset.utils import get_level_index -from ...config import C, REG_CN -from ...utils.resam import resam_ts_data -from ...log import get_module_logger +from ..data.data import D +from ..data.dataset.utils import get_level_index +from ..config import C, REG_CN +from ..utils.resam import resam_ts_data +from ..log import get_module_logger from .order import Order @@ -35,7 +35,7 @@ class Exchange: """__init__ :param freq: frequency of data - :param start_time: closed start time for backtest + :param start_time: closed start time for backtest :param end_time: closed end time for backtest :param codes: list stock_id list or a string of instruments(i.e. all, csi500, sse50) :param deal_price: str, 'close', 'open', 'vwap' diff --git a/qlib/contrib/backtest/executor.py b/qlib/backtest/executor.py similarity index 89% rename from qlib/contrib/backtest/executor.py rename to qlib/backtest/executor.py index c896f802d..88a219f41 100644 --- a/qlib/contrib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -3,8 +3,8 @@ import warnings import pandas as pd from typing import Union -from ...utils import init_instance_by_config -from ...utils.resam import parse_freq +from ..utils import init_instance_by_config +from ..utils.resam import parse_freq from .order import Order @@ -30,7 +30,7 @@ class BaseExecutor: Parameters ---------- time_per_step : str - trade time per trading step, used for genreate trade calendar + trade time per trading step, used for genreate the trade calendar generate_report : bool, optional whether to generate report, by default False verbose : bool, optional @@ -80,16 +80,18 @@ class BaseExecutor: if "start_time" in kwargs or "end_time" in kwargs: start_time = kwargs.get("start_time") end_time = kwargs.get("end_time") - self.calendar = TradeCalendarManager(freq=self.time_per_step, start_time=start_time, end_time=end_time) + self.trade_calendar = TradeCalendarManager( + freq=self.time_per_step, start_time=start_time, end_time=end_time + ) if common_infra is not None: self.reset_common_infra(common_infra) def get_level_infra(self): - return {"calendar": self.calendar} + return {"trade_calendar": self.trade_calendar} def finished(self): - return self.calendar.finished() + return self.trade_calendar.finished() def execute(self, trade_decision): """execute the trade decision and return the executed result @@ -117,8 +119,13 @@ class BaseExecutor: raise NotImplementedError("get_report is not implemented!") -class SplitExecutor(BaseExecutor): - from ...strategy.base import BaseStrategy +class NestedExecutor(BaseExecutor): + """ + Nested Executor with inner strategy and executor + - At each time `execute` is called, it will call the inner strategy and executor to execute the `trade_decision` in a higher frequency env. + """ + + from ..strategy.base import BaseStrategy def __init__( self, @@ -127,10 +134,10 @@ class SplitExecutor(BaseExecutor): inner_strategy: Union[BaseStrategy, dict], start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - trade_exchange: Exchange = None, generate_report: bool = False, verbose: bool = False, track_data: bool = False, + trade_exchange: Exchange = None, common_infra: dict = {}, **kwargs, ): @@ -153,7 +160,7 @@ class SplitExecutor(BaseExecutor): inner_strategy, common_infra=common_infra, accept_types=self.BaseStrategy ) - super(SplitExecutor, self).__init__( + super(NestedExecutor, self).__init__( time_per_step=time_per_step, start_time=start_time, end_time=end_time, @@ -173,7 +180,7 @@ class SplitExecutor(BaseExecutor): - reset trade_exchange - reset inner_strategyand inner_executor common infra """ - super(SplitExecutor, self).reset_common_infra(common_infra) + super(NestedExecutor, self).reset_common_infra(common_infra) if self.generate_report and "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") @@ -182,15 +189,15 @@ class SplitExecutor(BaseExecutor): self.inner_strategy.reset_common_infra(common_infra) def _init_sub_trading(self, trade_decision): - trade_index = self.calendar.get_trade_index() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) self.inner_executor.reset(start_time=trade_start_time, end_time=trade_end_time) sub_level_infra = self.inner_executor.get_level_infra() self.inner_strategy.reset(level_infra=sub_level_infra, outer_trade_decision=trade_decision) def _update_trade_account(self): - trade_index = self.calendar.get_trade_index() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) self.trade_account.update_bar_count() if self.generate_report: self.trade_account.update_bar_report( @@ -200,7 +207,6 @@ class SplitExecutor(BaseExecutor): ) def execute(self, trade_decision): - self.calendar.step() self._init_sub_trading(trade_decision) execute_result = [] _inner_execute_result = None @@ -210,13 +216,13 @@ class SplitExecutor(BaseExecutor): execute_result.extend(_inner_execute_result) if hasattr(self, "trade_account"): self._update_trade_account() - + self.trade_calendar.step() return execute_result def collect_data(self, trade_decision): if self.track_data: yield trade_decision - self.calendar.step() + self.trade_calendar.step() self._init_sub_trading(trade_decision) execute_result = [] _inner_execute_result = None @@ -240,15 +246,17 @@ class SplitExecutor(BaseExecutor): class SimulatorExecutor(BaseExecutor): + """Executor that simulate the true market""" + def __init__( self, time_per_step: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None, - trade_exchange: Exchange = None, generate_report: bool = False, verbose: bool = False, track_data: bool = False, + trade_exchange: Exchange = None, common_infra: dict = {}, **kwargs, ): @@ -282,9 +290,9 @@ class SimulatorExecutor(BaseExecutor): self.trade_exchange = common_infra.get("trade_exchange") def execute(self, trade_decision): - self.calendar.step() - trade_index = self.calendar.get_trade_index() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) execute_result = [] for order in trade_decision: if self.trade_exchange.check_order(order) is True: @@ -333,7 +341,7 @@ class SimulatorExecutor(BaseExecutor): trade_end_time=trade_end_time, trade_exchange=self.trade_exchange, ) - + self.trade_calendar.step() return execute_result def get_report(self): diff --git a/qlib/contrib/backtest/order.py b/qlib/backtest/order.py similarity index 100% rename from qlib/contrib/backtest/order.py rename to qlib/backtest/order.py diff --git a/qlib/contrib/backtest/position.py b/qlib/backtest/position.py similarity index 100% rename from qlib/contrib/backtest/position.py rename to qlib/backtest/position.py diff --git a/qlib/contrib/backtest/profit_attribution.py b/qlib/backtest/profit_attribution.py similarity index 99% rename from qlib/contrib/backtest/profit_attribution.py rename to qlib/backtest/profit_attribution.py index 20c6f638f..7e1844a6f 100644 --- a/qlib/contrib/backtest/profit_attribution.py +++ b/qlib/backtest/profit_attribution.py @@ -5,8 +5,8 @@ import numpy as np import pandas as pd from .position import Position -from ...data import D -from ...config import C +from ..data import D +from ..config import C import datetime from pathlib import Path diff --git a/qlib/contrib/backtest/report.py b/qlib/backtest/report.py similarity index 97% rename from qlib/contrib/backtest/report.py rename to qlib/backtest/report.py index 3763f5214..c26c46f9d 100644 --- a/qlib/contrib/backtest/report.py +++ b/qlib/backtest/report.py @@ -10,8 +10,8 @@ import warnings from pandas.core.frame import DataFrame -from ...utils.resam import parse_freq, resam_ts_data -from ...data import D +from ..utils.resam import parse_freq, resam_ts_data +from ..data import D class Report: @@ -86,9 +86,9 @@ class Report: try: _temp_result = D.features(_codes, fields, start_time, end_time, freq="day", disk_cache=1) except ValueError: - _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1) elif norm_freq == "minute": - _temp_result = D.features(_codes, fields, start_time, end_time, freq="minute", disk_cache=1) + _temp_result = D.features(_codes, fields, start_time, end_time, freq="1min", disk_cache=1) else: raise ValueError(f"benchmark freq {freq} is not supported") if len(_temp_result) == 0: diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py new file mode 100644 index 000000000..fe51c99f3 --- /dev/null +++ b/qlib/backtest/utils.py @@ -0,0 +1,98 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import pandas as pd +from typing import Union + +from ..utils.resam import get_resam_calendar +from ..data.data import Cal + + +class TradeCalendarManager: + """ + Manager for trading calendar + - BaseStrategy and BaseExecutor will use it + """ + + def __init__( + self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None + ): + """ + Parameters + ---------- + freq : str + frequency of trading calendar, also trade time per trading step + start_time : Union[str, pd.Timestamp], optional + closed start of the trading calendar, by default None + If `start_time` is None, it must be reset before trading. + end_time : Union[str, pd.Timestamp], optional + closed end of the trade time range, by default None + If `end_time` is None, it must be reset before trading. + """ + self.freq = freq + self.start_time = pd.Timestamp(start_time) if start_time else None + self.end_time = pd.Timestamp(end_time) if end_time else None + self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time) + + def _init_trade_calendar(self, freq, start_time, end_time): + """ + Reset the trade calendar + - self.trade_len : The total count for trading step + - self.trade_step : The number of trading step finished, self.trade_step can be [0, 1, 2, ..., self.trade_len - 1] + """ + _calendar, freq, freq_sam = get_resam_calendar(freq=freq) + self.trade_calendar = _calendar + _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) + self.start_index = _start_index + self.end_index = _end_index + self.trade_len = _end_index - _start_index + 1 + self.trade_step = 0 + + def finished(self): + """ + Check if the trading finished + - Should check before calling strategy.generate_decisions and executor.execute + - If self.trade_step >= self.self.trade_len, it means the trading is finished + - If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step + """ + return self.trade_step >= self.trade_len + + def step(self): + if self.finished(): + raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") + self.trade_step = self.trade_step + 1 + + def get_freq(self): + return self.freq + + def get_trade_len(self): + return self.trade_len + + def get_trade_step(self): + return self.trade_step + + def get_step_time(self, trade_step=0, shift=0): + """ + Get the time range of trading step + + Parameters + ---------- + trade_step : int, optional + the number of trading step finished, by default 0 + shift : int, optional + shift bars , by default 0 + + Returns + ------- + Tuple[pd.Timestamp, pd.Timestap] + - If shift == 0, return the trading time range + - If shift > 0, return the trading time range of the earlier shift bars + - If shift < 0, return the trading time range of the later shift bar + """ + trade_step = trade_step - shift + calendar_index = self.start_index + trade_step + return self.trade_calendar[calendar_index], self.trade_calendar[calendar_index + 1] - pd.Timedelta(seconds=1) + + def get_all_time(self): + """Get the start_time and end_time for trading""" + return self.start_time, self.end_time diff --git a/qlib/config.py b/qlib/config.py index df28ac939..c3085ae68 100644 --- a/qlib/config.py +++ b/qlib/config.py @@ -149,9 +149,9 @@ _default_config = { "task_db_name": "default_task_db", }, # Shift minute for highfreq minite data, used in backtest - # if min_data_shift == 0, use default market time [9:30, 11:29, 1:30, 2:59] - # if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:30, 2:59] - shift*minute - "min_data_shift": {0}, + # if min_data_shift == 0, use default market time [9:30, 11:29, 1:00, 2:59] + # if min_data_shift != 0, use shifted market time [9:30, 11:29, 1:00, 2:59] - shift*minute + "min_data_shift": 0, } MODE_CONF = { diff --git a/qlib/contrib/backtest/utils.py b/qlib/contrib/backtest/utils.py deleted file mode 100644 index 622816753..000000000 --- a/qlib/contrib/backtest/utils.py +++ /dev/null @@ -1,67 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - -import pandas as pd -from typing import Union - -from ...utils.resam import get_resam_calendar -from ...data.data import Cal - - -class TradeCalendarManager: - """ - Manager for trading calendar - - BaseStrategy and BaseExecutor will use it - """ - - def __init__( - self, freq: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None - ): - """ - Parameters - ---------- - freq : str - frequency of trading calendar, also trade time per trading step - start_time : Union[str, pd.Timestamp], optional - closed start of the trading calendar, by default None - If `start_time` is None, it must be reset before trading. - end_time : Union[str, pd.Timestamp], optional - closed end of the trade time range, by default None - If `end_time` is None, it must be reset before trading. - """ - self.freq = freq - self.start_time = pd.Timestamp(start_time) if start_time else None - self.end_time = pd.Timestamp(start_time) if start_time else None - self._init_trade_calendar(freq=freq, start_time=start_time, end_time=end_time) - - def _init_trade_calendar(self, freq, start_time, end_time): - """reset trade calendar""" - _calendar, freq, freq_sam = get_resam_calendar(freq=freq) - self.calendar = _calendar - _, _, _start_index, _end_index = Cal.locate_index(start_time, end_time, freq=freq, freq_sam=freq_sam) - self.start_index = _start_index - self.end_index = _end_index - self.trade_len = _end_index - _start_index + 1 - self.trade_index = 0 - - def finished(self): - return self.trade_index >= self.trade_len - - def step(self): - if self.finished(): - raise RuntimeError(f"The calendar is finished, please reset it if you want to call it!") - self.trade_index = self.trade_index + 1 - - def get_freq(self): - return self.freq - - def get_trade_len(self): - return self.trade_len - - def get_trade_index(self): - return self.trade_index - - def get_calendar_time(self, trade_index=1, shift=0): - trade_index = trade_index - shift - calendar_index = self.start_index + trade_index - return self.calendar[calendar_index - 1], self.calendar[calendar_index] - pd.Timedelta(seconds=1) diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 59a831f3e..8d4052cdb 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -9,7 +9,7 @@ import numpy as np import pandas as pd import warnings from ..log import get_module_logger -from .backtest import get_exchange, backtest as backtest_func +from ..backtest import get_exchange, backtest as backtest_func from ..utils import get_date_range from ..utils.resam import parse_freq @@ -141,9 +141,7 @@ def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **k whether to print log. """ - warnings.warn( - "this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning - ) + warnings.warn("this function is deprecated, please use backtest function in qlib.backtest", DeprecationWarning) report_dict = backtest_func( pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs ) diff --git a/qlib/contrib/report/analysis_position/parse_position.py b/qlib/contrib/report/analysis_position/parse_position.py index c5d48ff8e..1373d902f 100644 --- a/qlib/contrib/report/analysis_position/parse_position.py +++ b/qlib/contrib/report/analysis_position/parse_position.py @@ -4,7 +4,7 @@ import pandas as pd -from ...backtest.profit_attribution import get_stock_weight_df +from ....backtest.profit_attribution import get_stock_weight_df def parse_position(position: dict = None) -> pd.DataFrame: diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index 77743b10c..2927f12a2 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -97,7 +97,7 @@ def rank_label_graph( qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) - :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. + :param position: position data; **qlib.backtest.backtest** result. :param label_data: **D.features** result; index is **pd.MultiIndex**, index name is **[instrument, datetime]**; columns names is **[label]**. **The label T is the change from T to T+1**, it is recommended to use ``close``, example: `D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'])`. diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index 58e3fccc4..e7f6cce04 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -17,6 +17,7 @@ class SoftTopkStrategy(WeightStrategyBase): max_sold_weight=1.0, risk_degree=0.95, buy_method="first_fill", + trade_exchange=None, level_infra={}, common_infra={}, **kwargs, @@ -31,14 +32,14 @@ class SoftTopkStrategy(WeightStrategyBase): average_fill: assign the weight to the stocks rank high averagely. """ super(SoftTopkStrategy, self).__init__( - model, dataset, order_generator_cls_or_obj, level_infra, common_infra, **kwargs + model, dataset, order_generator_cls_or_obj, trade_exchange, level_infra, common_infra, **kwargs ) self.topk = topk self.max_sold_weight = max_sold_weight self.risk_degree = risk_degree self.buy_method = buy_method - def get_risk_degree(self, trade_index=None): + def get_risk_degree(self, trade_step=None): """get_risk_degree Return the proportion of your total value you will used in investment. Dynamically risk_degree will result in Market timing diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index d797729be..d563bccea 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -5,7 +5,7 @@ import pandas as pd from ...utils.resam import resam_ts_data from ...strategy.base import ModelStrategy -from ..backtest.order import Order +from ...backtest.order import Order from .order_generator import OrderGenWInteract @@ -21,6 +21,7 @@ class TopkDropoutStrategy(ModelStrategy): risk_degree=0.95, hold_thresh=1, only_tradable=False, + trade_exchange=None, level_infra={}, common_infra={}, **kwargs, @@ -47,6 +48,9 @@ class TopkDropoutStrategy(ModelStrategy): strategy will make buy sell decision without checking the tradable state of the stock. else: strategy will make decision with the tradable state of the stock info and avoid buy and sell them. + trade_exchange : Exchange + exchange that provides market info, used to deal order and generate report + - If `trade_exchange` is None, self.trade_exchange will be set with common_infra """ super(TopkDropoutStrategy, self).__init__( model, dataset, level_infra=level_infra, common_infra=common_infra, **kwargs @@ -58,6 +62,8 @@ class TopkDropoutStrategy(ModelStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable + if trade_exchange is not None: + self.trade_exchange = trade_exchange def reset_common_infra(self, common_infra): """ @@ -73,7 +79,7 @@ class TopkDropoutStrategy(ModelStrategy): if "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - def get_risk_degree(self, trade_index=None): + def get_risk_degree(self, trade_step=None): """get_risk_degree Return the proportion of your total value you will used in investment. Dynamically risk_degree will result in Market timing. @@ -82,9 +88,10 @@ class TopkDropoutStrategy(ModelStrategy): return self.risk_degree def generate_trade_decision(self, execute_result=None): - trade_index = self.calendar.get_trade_index() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] @@ -179,7 +186,7 @@ class TopkDropoutStrategy(ModelStrategy): continue if code in sell: # check hold limit - time_per_step = self.calendar.get_freq() + time_per_step = self.trade_calendar.get_freq() if current_temp.get_stock_count(code, bar=time_per_step) < self.hold_thresh: continue # sell order @@ -243,6 +250,7 @@ class WeightStrategyBase(ModelStrategy): model, dataset, order_generator_cls_or_obj=OrderGenWInteract, + trade_exchange=None, level_infra={}, common_infra={}, **kwargs, @@ -254,6 +262,8 @@ class WeightStrategyBase(ModelStrategy): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + if trade_exchange is not None: + self.trade_exchange = trade_exchange def reset_common_infra(self, common_infra): """ @@ -269,7 +279,7 @@ class WeightStrategyBase(ModelStrategy): if "trade_exchange" in common_infra: self.trade_exchange = common_infra.get("trade_exchange") - def get_risk_degree(self, trade_index=None): + def get_risk_degree(self, trade_step=None): """get_risk_degree Return the proportion of your total value you will used in investment. Dynamically risk_degree will result in Market timing. @@ -307,9 +317,11 @@ class WeightStrategyBase(ModelStrategy): """ # generate_trade_decision # generate_target_weight_position() and generate_order_list_from_target_weight_position() to generate order_list - trade_index = self.calendar.get_trade_index() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) + + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] + trade_step = self.trade_calendar.get_trade_step() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] @@ -320,7 +332,7 @@ class WeightStrategyBase(ModelStrategy): order_list = self.order_generator.generate_order_list_from_target_weight_position( current=current_temp, trade_exchange=self.trade_exchange, - risk_degree=self.get_risk_degree(trade_index), + risk_degree=self.get_risk_degree(trade_step), target_weight_position=target_weight_position, pred_start_time=pred_start_time, pred_end_time=pred_end_time, diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index db2c1de0d..d3e94551a 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -4,8 +4,8 @@ """ This order generator is for strategies based on WeightStrategyBase """ -from ..backtest.position import Position -from ..backtest.exchange import Exchange +from ...backtest.position import Position +from ...backtest.exchange import Exchange import pandas as pd import copy diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 1f42c451c..24873caae 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -3,13 +3,35 @@ import warnings from ...utils.resam import resam_ts_data from ...data.data import D from ...data.dataset.utils import convert_index_format -from ...strategy.base import RuleStrategy -from ..backtest.order import Order +from ...strategy.base import BaseStrategy +from ...backtest.order import Order +from ...backtest.exchange import Exchange -class TWAPStrategy(RuleStrategy): +class TWAPStrategy(BaseStrategy): """TWAP Strategy for trading""" + def __init__( + self, + outer_trade_decision: object = None, + trade_exchange: Exchange = None, + level_infra: dict = {}, + common_infra: dict = {}, + ): + """ + Parameters + ---------- + trade_exchange : Exchange + exchange that provides market info, used to deal order and generate report + - If `trade_exchange` is None, self.trade_exchange will be set with common_infra + """ + super(TWAPStrategy, self).__init__( + outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra + ) + + if trade_exchange is not None: + self.trade_exchange = trade_exchange + def reset_common_infra(self, common_infra): """ Parameters @@ -44,9 +66,11 @@ class TWAPStrategy(RuleStrategy): for order, _, _, _ in execute_result: self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - trade_index = self.calendar.get_trade_index() - trade_len = self.calendar.get_trade_len() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] + trade_step = self.trade_calendar.get_trade_step() + # get the total count of trading step + trade_len = self.trade_calendar.get_trade_len() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] for order in self.outer_trade_decision: if not self.trade_exchange.is_stock_tradable( @@ -57,21 +81,21 @@ class TWAPStrategy(RuleStrategy): _order_amount = None # consider trade unit if _amount_trade_unit is None: - # split the order equally - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1) + # divide the order equally + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: - # split the order equally - # floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1)) + # divide the order equally + # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - (trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_step) // (trade_len - trade_step + 1) * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or trade_index == trade_len + _order_amount is None or trade_step == trade_len ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -89,7 +113,7 @@ class TWAPStrategy(RuleStrategy): return order_list -class SBBStrategyBase(RuleStrategy): +class SBBStrategyBase(BaseStrategy): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. """ @@ -98,6 +122,27 @@ class SBBStrategyBase(RuleStrategy): TREND_SHORT = 1 TREND_LONG = 2 + def __init__( + self, + outer_trade_decision: object = None, + trade_exchange: Exchange = None, + level_infra: dict = {}, + common_infra: dict = {}, + ): + """ + Parameters + ---------- + trade_exchange : Exchange + exchange that provides market info, used to deal order and generate report + - If `trade_exchange` is None, self.trade_exchange will be set with common_infra + """ + super(SBBStrategyBase, self).__init__( + outer_trade_decision=outer_trade_decision, level_infra=level_infra, common_infra=common_infra + ) + + if trade_exchange is not None: + self.trade_exchange = trade_exchange + def reset_common_infra(self, common_infra): super(SBBStrategyBase, self).reset_common_infra(common_infra) if common_infra is not None: @@ -132,15 +177,17 @@ class SBBStrategyBase(RuleStrategy): if execute_result is not None: for order, _, _, _ in execute_result: self.trade_amount[(order.stock_id, order.direction)] -= order.deal_amount - trade_index = self.calendar.get_trade_index() - trade_len = self.calendar.get_trade_len() - trade_start_time, trade_end_time = self.calendar.get_calendar_time(trade_index) - pred_start_time, pred_end_time = self.calendar.get_calendar_time(trade_index, shift=1) + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] + trade_step = self.trade_calendar.get_trade_step() + # get the total count of trading step + trade_len = self.trade_calendar.get_trade_len() + trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) + pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) order_list = [] # for each order in in self.outer_trade_decision for order in self.outer_trade_decision: # predict the price trend - if trade_index % 2 == 1: + if trade_step % 2 == 0: _pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: _pred_trend = self.trade_trend[(order.stock_id, order.direction)] @@ -148,7 +195,7 @@ class SBBStrategyBase(RuleStrategy): if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time ): - if trade_index % 2 == 1: + if trade_step % 2 == 0: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend continue # get amount of one trade unit @@ -157,21 +204,21 @@ class SBBStrategyBase(RuleStrategy): _order_amount = None # considering trade unit if _amount_trade_unit is None: - # split the order equally - _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 1) + # divide the order equally + _order_amount = self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: # cal how many trade unit trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) - # split the order equally - # floor((trade_unit_cnt + trade_len - trade_index) / (trade_len - trade_index + 1)) == ceil(trade_unit_cnt / (trade_len - trade_index + 1)) + # divide the order equally + # floor((trade_unit_cnt + trade_len - trade_step - 1) / (trade_len - trade_step)) == ceil(trade_unit_cnt / (trade_len - trade_step)) _order_amount = ( - (trade_unit_cnt + trade_len - trade_index) // (trade_len - trade_index + 1) * _amount_trade_unit + (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] > 1e-5 and ( - _order_amount is None or trade_index == trade_len + _order_amount is None or trade_step == trade_len - 1 ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] @@ -190,31 +237,31 @@ class SBBStrategyBase(RuleStrategy): _order_amount = None # considering trade unit if _amount_trade_unit is None: - # N trade day last, split the order into N + 1 parts, and trade 2 parts + # N trade day left, divide the order into N + 1 parts, and trade 2 parts _order_amount = ( - 2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_index + 2) + 2 * self.trade_amount[(order.stock_id, order.direction)] / (trade_len - trade_step + 1) ) # without considering trade unit elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: # cal how many trade unit trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) - # N trade day last, split the order into N + 1 parts, and trade 2 parts + # N trade day left, divide the order into N + 1 parts, and trade 2 parts _order_amount = ( - (trade_unit_cnt + trade_len - trade_index + 1) - // (trade_len - trade_index + 2) + (trade_unit_cnt + trade_len - trade_step) + // (trade_len - trade_step + 1) * 2 * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last if self.trade_amount[(order.stock_id, order.direction)] >= 1e-5 and ( - _order_amount is None or trade_index == trade_len + _order_amount is None or trade_step == trade_len - 1 ): _order_amount = self.trade_amount[(order.stock_id, order.direction)] if _order_amount: _order_amount = min(_order_amount, self.trade_amount[(order.stock_id, order.direction)]) - if trade_index % 2 == 1: + if trade_step % 2 == 0: # in the first of two adjacent bar # if look short on the price, sell the stock more # if look long on the price, sell the stock more @@ -253,7 +300,7 @@ class SBBStrategyBase(RuleStrategy): ) order_list.append(_order) - if trade_index % 2 == 1: + if trade_step % 2 == 0: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend return order_list @@ -269,6 +316,7 @@ class SBBStrategyEMA(SBBStrategyBase): outer_trade_decision=[], instruments="csi300", freq="day", + trade_exchange: Exchange = None, level_infra={}, common_infra={}, **kwargs, @@ -288,13 +336,13 @@ class SBBStrategyEMA(SBBStrategyBase): if isinstance(instruments, str): self.instruments = D.instruments(instruments) self.freq = freq - super(SBBStrategyEMA, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) + super(SBBStrategyEMA, self).__init__(outer_trade_decision, trade_exchange, level_infra, common_infra, **kwargs) def _reset_signal(self): - trade_len = self.calendar.get_trade_len() + trade_len = self.trade_calendar.get_trade_len() fields = ["EMA($close, 10)-EMA($close, 20)"] - signal_start_time, _ = self.calendar.get_calendar_time(trade_index=1, shift=1) - _, signal_end_time = self.calendar.get_calendar_time(trade_index=trade_len, shift=1) + signal_start_time, _ = self.trade_calendar.get_step_time(trade_step=0, shift=1) + _, signal_end_time = self.trade_calendar.get_step_time(trade_step=trade_len - 1, shift=1) signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=signal_end_time, freq=self.freq ) @@ -314,8 +362,8 @@ class SBBStrategyEMA(SBBStrategyBase): else: self.level_infra.update(level_infra) - if "calendar" in level_infra: - self.calendar = level_infra.get("calendar") + if "trade_calendar" in level_infra: + self.trade_calendar = level_infra.get("trade_calendar") self._reset_signal() def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): diff --git a/qlib/data/data.py b/qlib/data/data.py index 394c3271e..9c61c225a 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -775,7 +775,7 @@ class ClientCalendarProvider(CalendarProvider): def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False): self.conn.send_request( - request_type="calendar", + request_type="trade_calendar", request_content={ "start_time": str(start_time), "end_time": str(end_time), @@ -990,7 +990,7 @@ class LocalProvider(BaseProvider): :param type: The type of resource for the uri :param **kwargs: """ - if type == "calendar": + if type == "trade_calendar": return Cal._uri(**kwargs) elif type == "instrument": return Inst._uri(**kwargs) diff --git a/qlib/rl/env.py b/qlib/rl/env.py index faf9c026e..3a77d2295 100644 --- a/qlib/rl/env.py +++ b/qlib/rl/env.py @@ -3,8 +3,9 @@ from typing import Union + +from ..backtest.executor import BaseExecutor from .interpreter import StateInterpreter, ActionInterpreter -from ..contrib.backtest.executor import BaseExecutor from ..utils import init_instance_by_config from .interpreter import BaseInterpreter diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 59d9d72e3..7828db609 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,15 +1,9 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import copy -import pandas as pd -from typing import List, Union - - from ..model.base import BaseModel from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format -from ..contrib.backtest.order import Order from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config @@ -44,8 +38,8 @@ class BaseStrategy: else: self.level_infra.update(level_infra) - if "calendar" in level_infra: - self.calendar = level_infra.get("calendar") + if "trade_calendar" in level_infra: + self.trade_calendar = level_infra.get("trade_calendar") def reset_common_infra(self, common_infra): if not hasattr(self, "common_infra"): @@ -83,12 +77,6 @@ class BaseStrategy: raise NotImplementedError("generate_trade_decision is not implemented!") -class RuleStrategy(BaseStrategy): - """Rule-based Trading strategy""" - - pass - - class ModelStrategy(BaseStrategy): """Model-based trading strategy, use model to make predictions for trading""" diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index b121b6130..cdac48533 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -40,7 +40,7 @@ def parse_freq(freq: str) -> Tuple[int, str]: raise ValueError( "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min" ) - _count = int(match_obj.group(1)) if match_obj.group(1) is None else 1 + _count = int(match_obj.group(1)) if match_obj.group(1) else 1 _freq = match_obj.group(2) _freq_format_dict = { "month": "month", @@ -58,7 +58,8 @@ def parse_freq(freq: str) -> Tuple[int, str]: def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: """ Resample the calendar with frequency freq_raw into the calendar with frequency freq_sam - Assumption: The fix length (240) of the calendar in each day. + Assumption: + - Fix length (240) of the calendar in each day. Parameters ---------- @@ -83,16 +84,19 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np if freq_sam == "minute": def cal_sam_minute(x, sam_minutes): + """ + Sample raw calendar into calendar with sam_minutes freq, shift represents the shift minute the market time + - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)] + - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)] + - mid open time of stock market is [13:00 - shift*pd.Timedelta(minutes=1)] + - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)] + """ day_time = pd.Timestamp(x.date()) shift = C.min_data_shift - # shift represents the shift minute the market time - # - open time of stock market is [9:30 - shift*pd.Timedelta(minutes=1)] - # - mid close time of stock market is [11:29 - shift*pd.Timedelta(minutes=1)] - # - mid open time of stock market is [13:30 - shift*pd.Timedelta(minutes=1)] - # - close time of stock market is [14:59 - shift*pd.Timedelta(minutes=1)] + open_time = day_time + pd.Timedelta(hours=9, minutes=30) - shift * pd.Timedelta(minutes=1) mid_close_time = day_time + pd.Timedelta(hours=11, minutes=29) - shift * pd.Timedelta(minutes=1) - mid_open_time = day_time + pd.Timedelta(hours=13, minutes=30) - shift * pd.Timedelta(minutes=1) + mid_open_time = day_time + pd.Timedelta(hours=13, minutes=00) - shift * pd.Timedelta(minutes=1) close_time = day_time + pd.Timedelta(hours=14, minutes=59) - shift * pd.Timedelta(minutes=1) if open_time <= x <= mid_close_time: @@ -101,7 +105,6 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np minute_index = (x - mid_open_time).seconds // 60 + 120 else: raise ValueError("datetime of calendar is out of range") - minute_index = minute_index // sam_minutes * sam_minutes if 0 <= minute_index < 120: @@ -109,7 +112,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np elif 120 <= minute_index < 240: return mid_open_time + (minute_index - 120) * pd.Timedelta(minutes=1) else: - raise ValueError("calendar minute_index error") + raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C") if freq_raw != "minute": raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") @@ -189,11 +192,13 @@ def get_resam_calendar( freq = "day" except ValueError: _calendar = Cal.calendar( - start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future + start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future ) freq = "min" elif norm_freq == "minute": - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future) + _calendar = Cal.calendar( + start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future + ) freq = "min" else: raise ValueError(f"freq {freq} is not supported") diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index a32ef9729..8abcd6c14 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -8,10 +8,10 @@ import pandas as pd from pathlib import Path from pprint import pprint from ..contrib.evaluate import risk_analysis -from ..contrib.backtest import backtest as normal_backtest from ..data.dataset import DatasetH from ..data.dataset.handler import DataHandlerLP +from ..backtest import backtest as normal_backtest from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict