From b41267fa593a83047713c4b099541dc640ecfb4b Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 25 Jun 2021 20:12:39 +0000 Subject: [PATCH] successful run random order gen in day script --- .../nested_decision_execution/workflow.py | 4 +- qlib/backtest/account.py | 17 +- qlib/backtest/backtest.py | 23 +- qlib/backtest/exchange.py | 27 ++- qlib/backtest/executor.py | 17 +- qlib/backtest/order.py | 199 +++++++++++++++++- qlib/backtest/position.py | 21 ++ qlib/backtest/report.py | 9 +- qlib/backtest/utils.py | 188 ----------------- qlib/contrib/evaluate.py | 12 +- qlib/contrib/strategy/model_strategy.py | 7 +- qlib/contrib/strategy/order_generator.py | 6 +- qlib/contrib/strategy/rule_strategy.py | 96 +++++---- qlib/strategy/base.py | 31 +-- qlib/utils/resam.py | 79 ++----- qlib/utils/time.py | 115 ++++++++++ qlib/workflow/record_temp.py | 8 +- 17 files changed, 505 insertions(+), 354 deletions(-) create mode 100644 qlib/utils/time.py diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index a44aee4ca..b6c1362fd 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -13,7 +13,7 @@ from qlib.tests.data import GetData from qlib.backtest import collect_data -class NestedDecisonExecutionWorkflow: +class NestedDecisionExecutionWorkflow: market = "csi300" benchmark = "SH000300" @@ -229,4 +229,4 @@ class NestedDecisonExecutionWorkflow: if __name__ == "__main__": - fire.Fire(NestedDecisonExecutionWorkflow) + fire.Fire(NestedDecisionExecutionWorkflow) diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index be1c25f95..64a814dba 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -76,7 +76,7 @@ class Account: 'kwargs': { "cash": init_cash }, - 'model_path': "qlib.backtest.position", + 'module_path': "qlib.backtest.position", }) self.accum_info = AccumulatedInfo() self.reset(freq=freq, benchmark_config=benchmark_config, init_report=True) @@ -164,13 +164,14 @@ class Account: def update_current(self, trade_start_time, trade_end_time, trade_exchange): """update current to make rtn consistent with earning at the end of bar""" # update price for stock in the position and the profit from changed_price - stock_list = self.current.get_stock_list() - for code in stock_list: - # if suspend, no new price to be updated, profit is 0 - if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): - continue - bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) - self.current.update_stock_price(stock_id=code, price=bar_close) + if not self.current.skip_update(): + stock_list = self.current.get_stock_list() + for code in stock_list: + # if suspend, no new price to be updated, profit is 0 + if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): + continue + bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) + self.current.update_stock_price(stock_id=code, price=bar_close) def update_report(self, trade_start_time, trade_end_time): """update position history, report""" diff --git a/qlib/backtest/backtest.py b/qlib/backtest/backtest.py index 18573115b..6ab17c5c5 100644 --- a/qlib/backtest/backtest.py +++ b/qlib/backtest/backtest.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.backtest.utils import TradeDecison +from qlib.backtest.order import BaseTradeDecision from qlib.strategy.base import BaseStrategy from qlib.backtest.executor import BaseExecutor -from ..utils.resam import parse_freq +from ..utils.time import Freq +from tqdm.auto import tqdm def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_executor: BaseExecutor): - """backtest funciton for the interaction of the outermost strategy and executor in the nested decison execution + """backtest funciton for the interaction of the outermost strategy and executor in the nested decision execution Returns ------- @@ -15,7 +16,7 @@ def backtest_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_exec it records the trading report information """ return_value = {} - for _decison in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): + for _decision in collect_data_loop(start_time, end_time, trade_strategy, trade_executor, return_value): pass return return_value.get("report"), return_value.get("indicator") @@ -45,22 +46,24 @@ def collect_data_loop(start_time, end_time, trade_strategy: BaseStrategy, trade_ level_infra = trade_executor.get_level_infra() trade_strategy.reset(level_infra=level_infra) - _execute_result = None - while not trade_executor.finished(): - _trade_decision: TradeDecison = trade_strategy.generate_trade_decision(_execute_result) - _execute_result = yield from trade_executor.collect_data(_trade_decision) + with tqdm(total=trade_executor.trade_calendar.get_trade_len(), desc="backtest loop") as bar: + _execute_result = None + while not trade_executor.finished(): + _trade_decision: BaseTradeDecision = trade_strategy.generate_trade_decision(_execute_result) + _execute_result = yield from trade_executor.collect_data(_trade_decision) + bar.update(trade_executor.trade_calendar.get_trade_step()) if return_value is not None: all_executors = trade_executor.get_all_executors() all_reports = { - "{}{}".format(*parse_freq(_executor.time_per_step)): _executor.get_report() + "{}{}".format(*Freq.parse(_executor.time_per_step)): _executor.get_report() for _executor in all_executors if _executor.generate_report } all_indicators = { "{}{}".format( - *parse_freq(_executor.time_per_step) + *Freq.parse(_executor.time_per_step) ): _executor.get_trade_indicator().generate_trade_indicators_dataframe() for _executor in all_executors } diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 06ecbaa5b..cffa98ba6 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -4,6 +4,7 @@ import random import logging +from typing import Union import numpy as np import pandas as pd @@ -259,6 +260,16 @@ class Exchange: return trade_val, trade_cost, trade_price + def create_order(self, code, amount, start_time, end_time, direction) -> Order: + return Order( + stock_id=code, + amount=amount, + start_time=start_time, + end_time=end_time, + direction=direction, + factor=self.get_factor(code, start_time, end_time), + ) + def get_quote_info(self, stock_id, start_time, end_time): return resam_ts_data(self.quote[stock_id], start_time, end_time, method="last").iloc[0] @@ -278,8 +289,20 @@ class Exchange: deal_price = self.get_close(stock_id, start_time, end_time) return deal_price - def get_factor(self, stock_id, start_time, end_time): - return resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last").iloc[0] + def get_factor(self, stock_id, start_time, end_time) -> Union[float, None]: + """ + Returns + ------- + Union[float, None]: + `None`: if the stock is suspended `None` may be returned + `float`: return factor if the factor exists + """ + if stock_id not in self.quote: + return None + res = resam_ts_data(self.quote[stock_id]["$factor"], start_time, end_time, method="last") + if res is not None: + res = res.iloc[0] + return res def generate_amount_position_from_weight_position(self, weight_position, cash, start_time, end_time): """ diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index bc4831f32..b6d16d58f 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -3,12 +3,12 @@ import warnings import pandas as pd from typing import Union -from .order import Order +from .order import Order, BaseTradeDecision from .exchange import Exchange -from .utils import BaseTradeDecision, TradeCalendarManager, CommonInfrastructure, LevelInfrastructure, TradeDecison +from .utils import TradeCalendarManager, CommonInfrastructure, LevelInfrastructure from ..utils import init_instance_by_config -from ..utils.resam import parse_freq +from ..utils.time import Freq from ..strategy.base import BaseStrategy @@ -135,7 +135,7 @@ class BaseExecutor: Parameters ---------- - trade_decision : TradeDecison + trade_decision : BaseTradeDecision Returns ---------- @@ -149,7 +149,7 @@ class BaseExecutor: Parameters ---------- - trade_decision : TradeDecison + trade_decision : BaseTradeDecision Returns ---------- @@ -261,7 +261,7 @@ class NestedExecutor(BaseExecutor): def execute(self, trade_decision): return_value = {} - for _decison in self.collect_data(trade_decision, return_value): + for _decision in self.collect_data(trade_decision, return_value): pass return return_value.get("execute_result") @@ -358,13 +358,12 @@ class SimulatorExecutor(BaseExecutor): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") - def execute(self, trade_decision): + def execute(self, trade_decision: BaseTradeDecision): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) execute_result = [] - order_generator = trade_decision.generator() - for order in order_generator: + for order in trade_decision.get_decision(): if self.trade_exchange.check_order(order) is True: # execute the order trade_val, trade_cost, trade_price = self.trade_exchange.deal_order( diff --git a/qlib/backtest/order.py b/qlib/backtest/order.py index e4bf41f1e..d1b5f6d08 100644 --- a/qlib/backtest/order.py +++ b/qlib/backtest/order.py @@ -1,8 +1,16 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# TODO: rename it with decision.py +from __future__ import annotations +# try to fix circular imports when enabling type hints +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from qlib.strategy.base import BaseStrategy +from qlib.backtest.utils import TradeCalendarManager +import warnings import pandas as pd from dataclasses import dataclass, field -from typing import ClassVar +from typing import ClassVar, Union, List, Set, Tuple @dataclass @@ -34,3 +42,192 @@ class Order: if self.direction not in {Order.SELL, Order.BUY}: raise NotImplementedError("direction not supported, `Order.SELL` for sell, `Order.BUY` for buy") self.deal_amount = 0 + + +class BaseTradeDecision: + """ + Trade decisions ara made by strategy and executed by exeuter + + Motivation: + Here are several typical scenarios for `BaseTradeDecision` + + Case 1: + 1. Outer strategy makes a decision. The decision is not available at the start of current interval + 2. After a period of time, the decision are updated and become available + 3. The inner strategy try to get the decision and start to execute the decision according to `get_range_limit` + Case 2: + 1. The strategy is available at the start of the interval + 2. Same as `case 1.3` + """ + def __init__(self, strategy: BaseStrategy): + """ + Parameters + ---------- + strategy : BaseStrategy + The strategy who make the decision + """ + self.strategy = strategy + + def get_decision(self) -> List[object]: + """ + get the **concrete decision** (e.g. execution orders) + This will be called by the inner strategy + + Returns + ------- + List[object]: + The decision result. Typically it is some orders + Example: + []: + Decision not available + concrete_decision: + available + """ + raise NotImplementedError(f"This type of input is not supported") + + def update(self, trade_calendar: TradeCalendarManager) -> Union["BaseTradeDecision", None]: + """ + Be called at the **start** of each step + + Parameters + ---------- + trade_calendar : TradeCalendarManager + The calendar of the **inner strategy**!!!!! + + Returns + ------- + None: + No update, use previous decision(or unavailable) + BaseTradeDecision: + New update, use new decision + """ + return self.strategy.update_trade_decision(self, trade_calendar) + + def get_range_limit(self) -> Tuple[int, int]: + """ + return the expected step range for limiting the decision execution time + Both left and right are **closed** + + Returns + ------- + Tuple[int, int]: + + Raises + ------ + NotImplementedError: + If the decision can't provide a unified start and end + """ + raise NotImplementedError(f"Please implement the `func` method") + + +class TradeDecisionWO(BaseTradeDecision): + """ + Trade Decision (W)ith (O)rder. + Besides, the time_range is also included. + """ + def __init__(self, order_list: List[Order], strategy: BaseStrategy, idx_range: Tuple=None): + super().__init__(strategy) + self.order_list = order_list + self.idx_range = idx_range + + def get_range_limit(self) -> Tuple[int, int]: + if self.idx_range is None: + # Default to get full index + return 0, self.strategy.trade_calendar.get_trade_len() - 1 + return self.idx_range + + def get_decision(self) -> List[object]: + return self.order_list + + +# TODO: the orders below need to be discussed ------------------------------------ +class TradeDecisionWithOrderPool: + """trade decision that made by strategy""" + + def __init__(self, strategy, order_pool): + """ + Parameters + ---------- + strategy : BaseStrategy + the original strategy that make the decision + order_pool : list, optional + the candinate order pool for generate trade decision + """ + super(TradeDecisionWithOrderPool, self).__init__(strategy) + self.order_pool = order_pool + self.order_list = [] + + def pop_order_pool(self, pop_len): + if pop_len > len(self.order_pool): + warnings.warn( + f"pop len {pop_len} is too much length than order pool, cut it as pool length {len(self.order_pool)}" + ) + pop_len = len(self.order_pool) + res = self.order_pool[:pop_len] + del self.order_pool[:pop_len] + return res + + def push_order_list(self, order_list): + self.order_list.extend(order_list) + + def get_decision(self): + """get the order list + + Parameters + ---------- + only_enable : bool, optional + wether to ignore disabled order, by default False + only_disable : bool, optional + wether to ignore enabled order, by default False + Returns + ------- + List[Order] + the order list + """ + return self.order_list + + def update(self, trade_calendar): + """make the original strategy update the enabled status of orders.""" + self.ori_strategy.update_trade_decision(self, trade_calendar) + + +class BaseDecisionUpdater: + def update_decision(self, decision, trade_calendar) -> BaseTradeDecision: + """[summary] + + Parameters + ---------- + decision : BaseTradeDecision + the trade decision to be updated + trade_calendar : BaseTradeCalendar + the trade calendar of inner execution + + Returns + ------- + BaseTradeDecision + the updated decision + """ + raise NotImplementedError(f"This method is not implemented") + + +class DecisionUpdaterWithOrderPool: + def __init__(self, plan_config=None): + """ + Parameters + ---------- + plan_config : Dict[Tuple(int, float)], optional + the plan config, by default None + """ + if plan_config is None: + self.plan_config = [(0, 1)] + else: + self.plan_config = plan_config + + def update_decision(self, decision, trade_calendar) -> BaseTradeDecision: + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] + trade_step = self.trade_calendar.get_trade_step() + for _index, _ratio in self.plan_config: + if trade_step == _index: + pop_len = len(decision.order_pool) * _ratio + pop_order_list = decision.pop_order_pool(pop_len) + decision.push_order_list(pop_order_list) diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index 6b021c913..70272f688 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -30,6 +30,23 @@ class BasePosition: """ return False + def check_stock(self, stock_id: str) -> bool: + """ + check if is the stock in the position + + Parameters + ---------- + stock_id : str + the id of the stock + + Returns + ------- + bool: + if is the stock in the position + """ + raise NotImplementedError(f"Please implement the `check_stock` method") + + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float): """ Parameters @@ -393,6 +410,10 @@ class InfPosition(BasePosition): """ Updating state is meaningless for InfPosition """ return True + def check_stock(self, stock_id: str) -> bool: + # InfPosition always have any stocks + return True + def update_order(self, order: Order, trade_val: float, cost: float, trade_price: float): pass diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 75b743694..70ebd724e 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -11,7 +11,8 @@ from pandas.core import groupby from pandas.core.frame import DataFrame -from ..utils.resam import parse_freq, resam_ts_data, get_higher_eq_freq_feature +from ..utils.time import Freq +from ..utils.resam import resam_ts_data, get_higher_eq_freq_feature from ..data import D from ..tests.config import CSI300_BENCH @@ -78,6 +79,9 @@ class Report: def _cal_benchmark(self, benchmark_config, freq): benchmark = benchmark_config.get("benchmark", CSI300_BENCH) + if benchmark is None: + return None + if isinstance(benchmark, pd.Series): return benchmark else: @@ -94,6 +98,9 @@ class Report: return _temp_result.groupby(level="datetime")[_temp_result.columns.tolist()[0]].mean().fillna(0) def _sample_benchmark(self, bench, trade_start_time, trade_end_time): + if self.bench is None: + return None + def cal_change(x): return (x + 1).prod() diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 85d88068a..d2441dd3a 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -1,10 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from qlib.backtest.order import Order -from qlib.strategy.base import BaseStrategy -from qlib.backtest.exchange import Exchange -from qlib.backtest.account import Account import pandas as pd import warnings from typing import Tuple, Union, List, Set @@ -150,187 +146,3 @@ class CommonInfrastructure(BaseInfrastructure): class LevelInfrastructure(BaseInfrastructure): def get_support_infra(self): return ["trade_calendar"] - - -class BaseTradeDecision: - # TODO: put it into order.py; and replace it with decision.py - def __init__(self, strategy: BaseStrategy): - self.strategy = strategy - - def get_decision(self) -> List[object]: - """ - get the **concrete decision** (e.g. concrete decision) - This will be called by the inner strategy - - Returns - ------- - List[object]: - The decision result. Typically it is some orders - Example: - []: - Decision not available - concrete_decision: - available - """ - raise NotImplementedError(f"This type of input is not supported") - - def update(self, trade_calendar: TradeCalendarManager) -> "BaseTradeDecison": - """ - Be called at the **start** of each step - - Parameters - ---------- - trade_calendar : TradeCalendarManager - The calendar of the **inner strategy**!!!!! - - Returns - ------- - None: - No update, use previous decision(or unavailable) - BaseTradeDecison: - New update, use new decision - """ - return self.strategy.update_trade_decision(self, trade_calendar) - - def get_range_limit(self) -> Tuple[int, int]: - """ - return the expected step range for limiting the decision execution time - - Returns - ------- - Tuple[int, int]: - - Raises - ------ - NotImplementedError: - If the decision can't provide a unified start and end - """ - raise NotImplementedError(f"Please implement the `func` method") - - -class TradeDecisonWO(BaseTradeDecision): - def __init__(self, order_list: List[Order], strategy: BaseStrategy): - super().__init__(strategy) - self.order_list = order_list - - -class TradeDecison(BaseTradeDecision): - """trade decision that made by strategy""" - - def __init__(self, order_list, ori_strategy, init_enable=False): - """ - Parameters - ---------- - order_list : list - the order list - ori_strategy : BaseStrategy - the original strategy that make the decison - init_enable : bool, optional - wether to enable order initially, default by False - """ - self.order_list = order_list - self.ori_strategy = ori_strategy - if init_enable: - self.enable_dict = {_order.stock_id: _order for _order in self.order_list} - self.disable_dict = dict() - else: - self.enable_dict = dict() - self.disable_dict = {_order.stock_id: _order for _order in self.order_list} - - def enable(self, enable_set: Union[List[str], Set[str]] = None, all_enable=False): - """enable order set - Parameters - ---------- - enable_set : Union[List[str], Set[str]], optional - the order set that will be enabled, by default None - - if all_enable is True, enable_set will be ignored - - else, enable the order whose stock_id in enable_set - all_enable : bool, optional - wether to enable all order, by default False - """ - if all_enable is True: - self.enable_dict.update(self.disable_dict) - self.disable_dict.clear() - if enable_set is not None: - warnings.warn(f"`enable_set` is ignored because `all_enable` is set True") - else: - enable_set = set(enable_set) - for _stock_id in enable_set: - enable_order = self.disable_dict.get(_stock_id) - if enable_order is None: - raise ValueError(f"_stock_id {_stock_id} is not found in disable set") - self.enable_order.update({_stock_id: enable_order}) - self.disable_dict.pop(_stock_id) - - def disable(self, disable_set: Union[List[str], Set[str]] = None, all_disable=False): - """disable order set - Parameters - ---------- - disable_set : Union[List[str], Set[str]], optional - the order set that will be disabled, by default None - - if all_disable is True, disable_set will be ignored - - else, disable the order whose stock_id in disable_set - all_disable : bool, optional - wether to disable all order, by default False - """ - if all_disable is True: - self.disable_dict.update(self.enable_dict) - self.enable_dict.clear() - if disable_set is not None: - warnings.warn(f"`disable_set` is ignored because `all_disable` is set True") - else: - disable_set = set(disable_set) - for _stock_id in disable_set: - disable_order = self.enable_dict.get(_stock_id) - if disable_order is None: - raise ValueError(f"_stock_id {_stock_id} is not found in enable set") - self.disable_dict.update({_stock_id: disable_order}) - self.enable_dict.pop(_stock_id) - - def generator(self, only_enable=False, only_disable=False): - """get order generator used for iteration - Parameters - ---------- - only_enable : bool, optional - wether to ignore disabled order, by default False - only_disable : bool, optional - wether to ignore enabled order, by default False - """ - if not only_disable and not only_enable: - yield from self.order_list - elif not only_disable: - yield from self.enable_dict.values() - elif not only_enable: - yield from self.disable_dict.values() - - def get_order_list(self, only_enable=False, only_disable=False): - """get the order list - - Parameters - ---------- - only_enable : bool, optional - wether to ignore disabled order, by default False - only_disable : bool, optional - wether to ignore enabled order, by default False - Returns - ------- - List[Order] - the order list - """ - if not only_disable and not only_enable: - return self.order_list - elif not only_disable: - return list(self.enable_dict.values()) - elif not only_enable: - return list(self.disable_dict.values()) - - def update(self, trade_calendar: TradeCalendarManager): - """ - make the original strategy update the enabled status of orders. - - Parameters - ---------- - trade_calendar : TradeCalendarManager - the trade calendar for sub strategy - """ - self.ori_strategy.update_trade_decision(self, trade_calendar) diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index a50be144a..f7728f911 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -11,7 +11,7 @@ import warnings from ..log import get_module_logger from ..backtest import get_exchange, backtest as backtest_func from ..utils import get_date_range -from ..utils.resam import parse_freq, NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY, NORM_FREQ_MINUTE +from ..utils.resam import Freq from ..data import D from ..config import C @@ -35,12 +35,12 @@ def risk_analysis(r, N: int = None, freq: str = "day"): """ def cal_risk_analysis_scaler(freq): - _count, _freq = parse_freq(freq) + _count, _freq = Freq.parse(freq) _freq_scaler = { - NORM_FREQ_MINUTE: 240 * 252, - NORM_FREQ_DAY: 252, - NORM_FREQ_WEEK: 50, - NORM_FREQ_MONTH: 12, + Freq.NORM_FREQ_MINUTE: 240 * 252, + Freq.NORM_FREQ_DAY: 252, + Freq.NORM_FREQ_WEEK: 50, + Freq.NORM_FREQ_MONTH: 12, } return _freq_scaler[_freq] / _count diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 71f9ee509..14e6f0810 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -6,8 +6,7 @@ import pandas as pd from ...utils.resam import resam_ts_data from ...strategy.base import ModelStrategy -from ...backtest.order import Order -from ...backtest.utils import TradeDecison +from ...backtest.order import Order, BaseTradeDecision from .order_generator import OrderGenWInteract @@ -247,7 +246,7 @@ class TopkDropoutStrategy(ModelStrategy): factor=factor, ) buy_order_list.append(buy_order) - return TradeDecison(order_list=sell_order_list + buy_order_list, ori_strategy=self) + return TradeDecision(order_list=sell_order_list + buy_order_list, ori_strategy=self) class WeightStrategyBase(ModelStrategy): @@ -344,4 +343,4 @@ class WeightStrategyBase(ModelStrategy): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecision(order_list=order_list, ori_strategy=self) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index 7e4ee1a07..f822609c8 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -6,7 +6,7 @@ This order generator is for strategies based on WeightStrategyBase """ from ...backtest.position import Position from ...backtest.exchange import Exchange -from ...backtest.utils import TradeDecison +from ...backtest.order import BaseTradeDecision import pandas as pd import copy @@ -127,7 +127,7 @@ class OrderGenWInteract(OrderGenerator): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecision(order_list=order_list, ori_strategy=self) class OrderGenWOInteract(OrderGenerator): @@ -191,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator): trade_start_time=trade_start_time, trade_end_time=trade_end_time, ) - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecision(order_list=order_list, ori_strategy=self) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index c0993f44e..0d44e02a5 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -7,9 +7,9 @@ from ...utils.resam import resam_ts_data from ...data.data import D from ...data.dataset.utils import convert_index_format from ...strategy.base import BaseStrategy -from ...backtest.order import Order +from ...backtest.order import BaseTradeDecision, Order, TradeDecisionWO from ...backtest.exchange import Exchange -from ...backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeDecison +from ...backtest.utils import CommonInfrastructure, LevelInfrastructure class TWAPStrategy(BaseStrategy): @@ -17,7 +17,7 @@ class TWAPStrategy(BaseStrategy): def __init__( self, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, @@ -25,8 +25,8 @@ class TWAPStrategy(BaseStrategy): """ Parameters ---------- - outer_trade_decision : TradeDecison - the trade decison of outer strategy which this startegy relies + outer_trade_decision : BaseTradeDecision + the trade decision of outer strategy which this startegy relies trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -57,25 +57,35 @@ class TWAPStrategy(BaseStrategy): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ Parameters ---------- - outer_trade_decision : TradeDecison, optional + outer_trade_decision : BaseTradeDecision, optional """ super(TWAPStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: self.trade_amount = {} - outer_order_generator = outer_trade_decision.generator() - for order in outer_order_generator: + for order in outer_trade_decision.get_decision(): self.trade_amount[order.stock_id] = order.amount def generate_trade_decision(self, execute_result=None): + # strategy is not available. Give an empty decision + if len(self.outer_trade_decision.get_decision()) == 0: + return TradeDecisionWO(order_list=[], strategy=self) + # get the number of trading step finished, trade_step can be [0, 1, 2, ..., trade_len - 1] trade_step = self.trade_calendar.get_trade_step() # get the total count of trading step - trade_len = self.trade_calendar.get_trade_len() + start_idx, end_idx = self.outer_trade_decision.get_range_limit() + trade_len = end_idx - start_idx + 1 + + if trade_step < start_idx: + # It is not time to start trading + return TradeDecisionWO(order_list=[], strategy=self) + + rel_trade_step = trade_step - start_idx # trade_step relative to start_idx # update the order amount if execute_result is not None: @@ -84,8 +94,7 @@ class TWAPStrategy(BaseStrategy): trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) order_list = [] - outer_order_generator = self.outer_trade_decision.generator(only_enable=True) - for order in outer_order_generator: + for order in self.outer_trade_decision.get_decision(): # if not tradable, continue if not self.trade_exchange.is_stock_tradable( stock_id=order.stock_id, start_time=trade_start_time, end_time=trade_end_time @@ -96,21 +105,21 @@ class TWAPStrategy(BaseStrategy): # considering trade unit if _amount_trade_unit is None: # divide the order into equal parts, and trade one part - _order_amount = self.trade_amount[order.stock_id] / (trade_len - trade_step) + _order_amount = self.trade_amount[order.stock_id] / (trade_len - rel_trade_step) # without considering trade unit else: # divide the order into equal parts, and trade one part # calculate the total count of trade units to trade trade_unit_cnt = int(self.trade_amount[order.stock_id] // _amount_trade_unit) # calculate the amount of one part, ceil the amount - # floor((trade_unit_cnt + trade_len - trade_step) / (trade_len - trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - trade_step + 1)) + # floor((trade_unit_cnt + trade_len - rel_trade_step) / (trade_len - rel_trade_step + 1)) == ceil(trade_unit_cnt / (trade_len - rel_trade_step + 1)) _order_amount = ( - (trade_unit_cnt + trade_len - trade_step - 1) // (trade_len - trade_step) * _amount_trade_unit + (trade_unit_cnt + trade_len - rel_trade_step - 1) // (trade_len - rel_trade_step) * _amount_trade_unit ) if order.direction == order.SELL: # sell all amount at last - if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or trade_step == trade_len - 1): + if self.trade_amount[order.stock_id] > 1e-5 and (_order_amount < 1e-5 or rel_trade_step == trade_len - 1): _order_amount = self.trade_amount[order.stock_id] _order_amount = min(_order_amount, self.trade_amount[order.stock_id]) @@ -126,7 +135,7 @@ class TWAPStrategy(BaseStrategy): factor=order.factor, ) order_list.append(_order) - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecisionWO(order_list=order_list, strategy=self) class SBBStrategyBase(BaseStrategy): @@ -140,7 +149,7 @@ class SBBStrategyBase(BaseStrategy): def __init__( self, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, trade_exchange: Exchange = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, @@ -148,8 +157,8 @@ class SBBStrategyBase(BaseStrategy): """ Parameters ---------- - outer_trade_decision : TradeDecison - the trade decison of outer strategy which this startegy relies + outer_trade_decision : BaseTradeDecision + the trade decision of outer strategy which this startegy relies trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -178,11 +187,11 @@ class SBBStrategyBase(BaseStrategy): if common_infra.has("trade_exchange"): self.trade_exchange = common_infra.get("trade_exchange") - def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ Parameters ---------- - outer_trade_decision : TradeDecison, optional + outer_trade_decision : BaseTradeDecision, optional """ super(SBBStrategyBase, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: @@ -336,7 +345,7 @@ class SBBStrategyBase(BaseStrategy): # in the first one of two adjacent bars, store the trend for the second one to use self.trade_trend[order.stock_id] = _pred_trend - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecision(order_list=order_list, ori_strategy=self) class SBBStrategyEMA(SBBStrategyBase): @@ -346,7 +355,7 @@ class SBBStrategyEMA(SBBStrategyBase): def __init__( self, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, instruments: Union[List, str] = "csi300", freq: str = "day", trade_exchange: Exchange = None, @@ -426,7 +435,7 @@ class ACStrategy(BaseStrategy): lamb: float = 1e-6, eta: float = 2.5e-6, window_size: int = 20, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, instruments: Union[List, str] = "csi300", freq: str = "day", trade_exchange: Exchange = None, @@ -503,11 +512,11 @@ class ACStrategy(BaseStrategy): self.trade_calendar = level_infra.get("trade_calendar") self._reset_signal() - def reset(self, outer_trade_decision: TradeDecison = None, **kwargs): + def reset(self, outer_trade_decision: BaseTradeDecision = None, **kwargs): """ Parameters ---------- - outer_trade_decision : TradeDecison, optional + outer_trade_decision : BaseTradeDecision, optional """ super(ACStrategy, self).reset(outer_trade_decision=outer_trade_decision, **kwargs) if outer_trade_decision is not None: @@ -592,13 +601,13 @@ class ACStrategy(BaseStrategy): factor=order.factor, ) order_list.append(_order) - return TradeDecison(order_list=order_list, ori_strategy=self) + return TradeDecision(order_list=order_list, ori_strategy=self) class RandomOrderStrategy(BaseStrategy): def __init__(self, - time_range: Tuple = ("9:30", "15:00"), # The range is closed on both left and right. + index_range: Tuple[int, int], # The range is closed on both left and right. sample_ratio: float = 1., volume_ratio: float = 0.01, market: str = "all", @@ -607,10 +616,10 @@ class RandomOrderStrategy(BaseStrategy): """ Parameters ---------- - time_range : Tuple - the intra day time range of the orders + index_range : Tuple + the intra day time index range of the orders the left and right is closed. - # TODO: this is a time_range level limitation. We'll implement a more detailed limitation later. + # TODO: this is a index_range level limitation. We'll implement a more detailed limitation later. sample_ratio : float the ratio of all orders are sampled volume_ratio : float @@ -621,12 +630,27 @@ class RandomOrderStrategy(BaseStrategy): """ super().__init__(*args, **kwargs) - self.time_range = time_range + self.index_range = index_range self.sample_ratio = sample_ratio self.volume_ratio = volume_ratio self.market = market - exch: Exchange = self.common_infra.get("exchange") - self.volume = D.features(D.instruments("market"), ["Mean($volume, 10)"], start_time=exch.start_time, end_time=exch.end_time) + exch: Exchange = self.common_infra.get("trade_exchange") + self.volume = D.features(D.instruments(market), ["Mean(Ref($volume, 1), 10)"], start_time=exch.start_time, end_time=exch.end_time) + self.volume_df = self.volume.iloc[:, 0].unstack() def generate_trade_decision(self, execute_result=None): - return super().generate_trade_decision(execute_result=execute_result) + trade_step = self.trade_calendar.get_trade_step() + step_time_start, step_time_end = self.trade_calendar.get_step_time(trade_step) + + order_list = [] + for direction in Order.SELL, Order.BUY: + for stock_id, volume in self.volume_df[step_time_start].dropna().sample(frac=self.sample_ratio).items(): + order_list.append( + self.common_infra.get("trade_exchange").create_order( + code=stock_id, + amount=volume * self.volume_ratio, + start_time=step_time_start, + end_time=step_time_end, + direction=direction, # 1 for buy + )) + return TradeDecisionWO(order_list, self) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index f060ccdb7..b20b0db66 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -7,7 +7,8 @@ from ..data.dataset import DatasetH from ..data.dataset.utils import convert_index_format from ..rl.interpreter import ActionInterpreter, StateInterpreter from ..utils import init_instance_by_config -from ..backtest.utils import BaseTradeDecision, CommonInfrastructure, LevelInfrastructure, TradeCalendarManager, TradeDecison +from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager +from ..backtest.order import BaseTradeDecision class BaseStrategy: @@ -15,16 +16,16 @@ class BaseStrategy: def __init__( self, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, ): """ Parameters ---------- - outer_trade_decision : TradeDecison, optional - the trade decison of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None - - If the strategy is used to split trade decison, it will be used + outer_trade_decision : BaseTradeDecision, optional + the trade decision of outer strategy which this startegy relies, and it will be traded in [start_time, end_time], by default None + - If the strategy is used to split trade decision, it will be used - If the strategy is used for portfolio management, it can be ignored level_infra : LevelInfrastructure, optional level shared infrastructure for backtesting, including trade calendar @@ -34,14 +35,14 @@ class BaseStrategy: self.reset(level_infra=level_infra, common_infra=common_infra, outer_trade_decision=outer_trade_decision) - def reset_level_infra(self, level_infra): + def reset_level_infra(self, level_infra: LevelInfrastructure): if not hasattr(self, "level_infra"): self.level_infra = level_infra else: self.level_infra.update(level_infra) if level_infra.has("trade_calendar"): - self.trade_calendar = level_infra.get("trade_calendar") + self.trade_calendar: TradeCalendarManager = level_infra.get("trade_calendar") def reset_common_infra(self, common_infra: CommonInfrastructure): if not hasattr(self, "common_infra"): @@ -62,7 +63,7 @@ class BaseStrategy: """ - reset `level_infra`, used to reset trade calendar, .etc - reset `common_infra`, used to reset `trade_account`, `trade_exchange`, .etc - - reset `outer_trade_decision`, used to make split decison + - reset `outer_trade_decision`, used to make split decision """ if level_infra is not None: self.reset_level_infra(level_infra) @@ -79,19 +80,19 @@ class BaseStrategy: Parameters ---------- execute_result : List[object], optional - the executed result for trade decison, by default None + the executed result for trade decision, by default None - When call the generate_trade_decision firstly, `execute_result` could be None """ raise NotImplementedError("generate_trade_decision is not implemented!") - def update_trade_decision(self, trade_decison: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]: + def update_trade_decision(self, trade_decision: BaseTradeDecision, trade_calendar: TradeCalendarManager) -> Union[BaseTradeDecision, None]: """ update trade decision in each step of inner execution, this method enable all order Parameters ---------- - trade_decison : TradeDecison - the trade decison that will be updated + trade_decision : BaseTradeDecision + the trade decision that will be updated trade_calendar : TradeCalendarManager The calendar of the **inner strategy**!!!!! @@ -125,7 +126,7 @@ class ModelStrategy(BaseStrategy): self, model: BaseModel, dataset: DatasetH, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, @@ -161,7 +162,7 @@ class RLStrategy(BaseStrategy): def __init__( self, policy, - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, @@ -184,7 +185,7 @@ class RLIntStrategy(RLStrategy): policy, state_interpreter: Union[dict, StateInterpreter], action_interpreter: Union[dict, ActionInterpreter], - outer_trade_decision: TradeDecison = None, + outer_trade_decision: BaseTradeDecision = None, level_infra: LevelInfrastructure = None, common_infra: CommonInfrastructure = None, **kwargs, diff --git a/qlib/utils/resam.py b/qlib/utils/resam.py index d28076d88..ae0cdf9d1 100644 --- a/qlib/utils/resam.py +++ b/qlib/utils/resam.py @@ -7,58 +7,7 @@ from typing import Tuple, List, Union, Optional, Callable from . import lazy_sort_index from ..config import C - -NORM_FREQ_MONTH = "month" -NORM_FREQ_WEEK = "week" -NORM_FREQ_DAY = "day" -NORM_FREQ_MINUTE = "minute" - - -def parse_freq(freq: str) -> Tuple[int, str]: - """ - Parse freq into a unified format - - Parameters - ---------- - freq : str - Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$' - - Returns - ------- - freq: Tuple[int, str] - Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'. - Example: - - .. code-block:: - - print(parse_freq("day")) - (1, "day" ) - print(parse_freq("2mon")) - (2, "month") - print(parse_freq("10w")) - (10, "week") - - """ - freq = freq.lower() - match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq) - if match_obj is None: - raise ValueError( - "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min" - ) - _count = int(match_obj.group(1)) if match_obj.group(1) else 1 - _freq = match_obj.group(2) - _freq_format_dict = { - "month": NORM_FREQ_MONTH, - "mon": NORM_FREQ_MONTH, - "week": NORM_FREQ_WEEK, - "w": NORM_FREQ_WEEK, - "day": NORM_FREQ_DAY, - "d": NORM_FREQ_DAY, - "minute": NORM_FREQ_MINUTE, - "min": NORM_FREQ_MINUTE, - } - return _count, _freq_format_dict[_freq] - +from .time import Freq def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: """ @@ -80,13 +29,13 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np np.ndarray The calendar with frequency freq_sam """ - raw_count, freq_raw = parse_freq(freq_raw) - sam_count, freq_sam = parse_freq(freq_sam) + raw_count, freq_raw = Freq.parse(freq_raw) + sam_count, freq_sam = Freq.parse(freq_sam) if not len(calendar_raw): return calendar_raw # if freq_sam is xminute, divide each trading day into several bars evenly - if freq_sam == NORM_FREQ_MINUTE: + if freq_sam == Freq.NORM_FREQ_MINUTE: def cal_sam_minute(x, sam_minutes): """ @@ -119,7 +68,7 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np else: raise ValueError("calendar minute_index error, check `min_data_shift` in qlib.config.C") - if freq_raw != NORM_FREQ_MINUTE: + if freq_raw != Freq.NORM_FREQ_MINUTE: raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") else: if raw_count > sam_count: @@ -130,15 +79,15 @@ def resam_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np # else, convert the raw calendar into day calendar, and divide the whole calendar into several bars evenly else: _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) - if freq_sam == NORM_FREQ_DAY: + if freq_sam == Freq.NORM_FREQ_DAY: return _calendar_day[::sam_count] - elif freq_sam == NORM_FREQ_WEEK: + elif freq_sam == Freq.NORM_FREQ_WEEK: _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0] return _calendar_week[::sam_count] - elif freq_sam == NORM_FREQ_MONTH: + elif freq_sam == Freq.NORM_FREQ_MONTH: _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0] return _calendar_month[::sam_count] @@ -180,7 +129,7 @@ def get_resam_calendar( """ - _, norm_freq = parse_freq(freq) + _, norm_freq = Freq.parse(freq) from ..data.data import Cal @@ -189,7 +138,7 @@ def get_resam_calendar( freq, freq_sam = freq, None except (ValueError, KeyError): freq_sam = freq - if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]: + if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]: try: _calendar = Cal.calendar( start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future @@ -200,7 +149,7 @@ def get_resam_calendar( start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future ) freq = "1min" - elif norm_freq == NORM_FREQ_MINUTE: + elif norm_freq == Freq.NORM_FREQ_MINUTE: _calendar = Cal.calendar( start_time=start_time, end_time=end_time, freq="1min", freq_sam=freq, future=future ) @@ -224,15 +173,15 @@ def get_higher_eq_freq_feature(instruments, fields, start_time=None, end_time=No _result = D.features(instruments, fields, start_time, end_time, freq=freq, disk_cache=disk_cache) _freq = freq except (ValueError, KeyError): - _, norm_freq = parse_freq(freq) - if norm_freq in [NORM_FREQ_MONTH, NORM_FREQ_WEEK, NORM_FREQ_DAY]: + _, norm_freq = Freq.parse(freq) + if norm_freq in [Freq.NORM_FREQ_MONTH, Freq.NORM_FREQ_WEEK, Freq.NORM_FREQ_DAY]: try: _result = D.features(instruments, fields, start_time, end_time, freq="day", disk_cache=disk_cache) _freq = "day" except (ValueError, KeyError): _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) _freq = "1min" - elif norm_freq == NORM_FREQ_MINUTE: + elif norm_freq == Freq.NORM_FREQ_MINUTE: _result = D.features(instruments, fields, start_time, end_time, freq="1min", disk_cache=disk_cache) _freq = "1min" else: diff --git a/qlib/utils/time.py b/qlib/utils/time.py new file mode 100644 index 000000000..6e3bd71a3 --- /dev/null +++ b/qlib/utils/time.py @@ -0,0 +1,115 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +Time related utils are compiled in this script +""" +import bisect +from datetime import time +from typing import List, Tuple +import re +from numpy import append +import pandas as pd + + +def get_min_cal() -> List[time]: + """ + get the minute level calendar in day period + + Returns + ------- + List[time]: + + """ + cal = [] + for ts in list(pd.date_range("9:30", "11:29", freq="1min")) + list(pd.date_range("13:00", "14:59", freq="1min")): + cal.append(ts.time()) + return cal + + +class Freq: + NORM_FREQ_MONTH = "month" + NORM_FREQ_WEEK = "week" + NORM_FREQ_DAY = "day" + NORM_FREQ_MINUTE = "minute" + SUPPORT_CAL_LIST = [NORM_FREQ_MINUTE] + + MIN_CAL = get_min_cal() + + def __init__(self, freq: str) -> None: + self.count, self.base = self.parse(freq) + + @staticmethod + def parse(freq: str) -> Tuple[int, str]: + """ + Parse freq into a unified format + + Parameters + ---------- + freq : str + Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$' + + Returns + ------- + freq: Tuple[int, str] + Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'. + Example: + + .. code-block:: + + print(Freq.parse("day")) + (1, "day" ) + print(Freq.parse("2mon")) + (2, "month") + print(Freq.parse("10w")) + (10, "week") + + """ + freq = freq.lower() + match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq) + if match_obj is None: + raise ValueError( + "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min" + ) + _count = int(match_obj.group(1)) if match_obj.group(1) else 1 + _freq = match_obj.group(2) + _freq_format_dict = { + "month": Freq.NORM_FREQ_MONTH, + "mon": Freq.NORM_FREQ_MONTH, + "week": Freq.NORM_FREQ_WEEK, + "w": Freq.NORM_FREQ_WEEK, + "day": Freq.NORM_FREQ_DAY, + "d": Freq.NORM_FREQ_DAY, + "minute": Freq.NORM_FREQ_MINUTE, + "min": Freq.NORM_FREQ_MINUTE, + } + return _count, _freq_format_dict[_freq] + + +def get_day_min_idx_range(start: str, end: str, freq: str) -> Tuple[int, int]: + """ + get the min-bar index in a day for a time range (both left and right is closed) given a fixed frequency + Parameters + ---------- + start : str + e.g. "9:30" + end : str + e.g. "14:30" + freq : str + "1min" + + Returns + ------- + Tuple[int, int]: + The index of start and end in the calendar. Both left and right are **closed** + """ + start = pd.Timestamp(start).time() + end = pd.Timestamp(end).time() + freq = Freq(freq) + in_day_cal = Freq.MIN_CAL[::freq.count] + left_idx = bisect.bisect_left(in_day_cal, start) + right_idx = bisect.bisect_right(in_day_cal, end) - 1 + return left_idx, right_idx + + +if __name__ == "__main__": + print(get_day_min_idx_range("8:30", "14:59", "10min")) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 0f6950587..549658071 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -16,7 +16,7 @@ from ..backtest import backtest as normal_backtest from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..utils.resam import parse_freq +from ..utils.time import Freq from ..strategy.base import BaseStrategy from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec @@ -344,17 +344,17 @@ class PortAnaRecord(RecordTemp): indicator_analysis_freq = [indicator_analysis_freq] self.risk_analysis_freq = [ - "{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in risk_analysis_freq + "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in risk_analysis_freq ] self.indicator_analysis_freq = [ - "{0}{1}".format(*parse_freq(_analysis_freq)) for _analysis_freq in indicator_analysis_freq + "{0}{1}".format(*Freq.parse(_analysis_freq)) for _analysis_freq in indicator_analysis_freq ] self.indicator_analysis_method = indicator_analysis_method def _get_report_freq(self, executor_config): ret_freq = [] if executor_config["kwargs"].get("generate_report", False): - _count, _freq = parse_freq(executor_config["kwargs"]["time_per_step"]) + _count, _freq = Freq.parse(executor_config["kwargs"]["time_per_step"]) ret_freq.append(f"{_count}{_freq}") if "sub_env" in executor_config["kwargs"]: ret_freq.extend(self._get_report_freq(executor_config["kwargs"]["sub_env"]))