diff --git a/examples/highfreq/data/README.md b/examples/highfreq/README.md similarity index 100% rename from examples/highfreq/data/README.md rename to examples/highfreq/README.md diff --git a/examples/highfreq/data/highfreq_handler.py b/examples/highfreq/highfreq_handler.py similarity index 100% rename from examples/highfreq/data/highfreq_handler.py rename to examples/highfreq/highfreq_handler.py diff --git a/examples/highfreq/data/highfreq_ops.py b/examples/highfreq/highfreq_ops.py similarity index 100% rename from examples/highfreq/data/highfreq_ops.py rename to examples/highfreq/highfreq_ops.py diff --git a/examples/highfreq/data/highfreq_processor.py b/examples/highfreq/highfreq_processor.py similarity index 100% rename from examples/highfreq/data/highfreq_processor.py rename to examples/highfreq/highfreq_processor.py diff --git a/examples/highfreq/data/workflow.py b/examples/highfreq/workflow.py similarity index 100% rename from examples/highfreq/data/workflow.py rename to examples/highfreq/workflow.py diff --git a/examples/highfreq/backtest/workflow.py b/examples/multi_level_trading/workflow.py similarity index 89% rename from examples/highfreq/backtest/workflow.py rename to examples/multi_level_trading/workflow.py index 786469d8b..9b0e6dc77 100644 --- a/examples/highfreq/backtest/workflow.py +++ b/examples/multi_level_trading/workflow.py @@ -91,13 +91,13 @@ if __name__ == "__main__": }, }, "env": { - "class": "SplitEnv", - "module_path": "qlib.contrib.backtest.env", + "class": "SplitExecutor", + "module_path": "qlib.contrib.backtest.executor", "kwargs": { "step_bar": "week", "sub_env": { - "class": "SimulatorEnv", - "module_path": "qlib.contrib.backtest.env", + "class": "SimulatorExecutor", + "module_path": "qlib.contrib.backtest.executor", "kwargs": { "step_bar": "day", "verbose": True, @@ -118,14 +118,17 @@ if __name__ == "__main__": "backtest": { "start_time": trade_start_time, "end_time": trade_end_time, - "verbose": False, - "limit_threshold": 0.095, 
"account": 100000000, "benchmark": benchmark, - "deal_price": "close", - "open_cost": 0.0005, - "close_cost": 0.0015, - "min_cost": 5, + "exchange_kwargs": { + "freq": "day", + "verbose": False, + "limit_threshold": 0.095, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + }, }, } diff --git a/qlib/contrib/backtest/__init__.py b/qlib/contrib/backtest/__init__.py index dacbdfefc..c8114d852 100644 --- a/qlib/contrib/backtest/__init__.py +++ b/qlib/contrib/backtest/__init__.py @@ -1,15 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from .order import Order -from .position import Position + from .exchange import Exchange -from .report import Report +from .executor import BaseExecutor from .backtest import backtest as backtest_func -import copy -import numpy as np import inspect +from ...strategy.base import BaseStrategy from ...utils import init_instance_by_config from ...log import get_module_logger from ...config import C @@ -90,21 +88,6 @@ def get_exchange( return init_instance_by_config(exchange, accept_types=Exchange) -def init_env_instance_by_config(env): - if isinstance(env, dict): - env_config = copy.copy(env) - if "kwargs" in env_config: - env_kwargs = copy.copy(env_config["kwargs"]) - if "sub_env" in env_kwargs: - env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"]) - if "sub_strategy" in env_kwargs: - env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"]) - env_config["kwargs"] = env_kwargs - return init_instance_by_config(env_config) - else: - return env - - def setup_exchange(root_instance, trade_exchange=None, force=False): if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args: if force: @@ -118,13 +101,11 @@ def setup_exchange(root_instance, trade_exchange=None, force=False): setup_exchange(root_instance.sub_strategy, trade_exchange) -def backtest(start_time, end_time, strategy, env, benchmark="SH000905", 
account=1e9, **kwargs): - trade_strategy = init_instance_by_config(strategy) - trade_env = init_env_instance_by_config(env) +def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}): + trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy) + trade_env = init_instance_by_config(env, accept_types=BaseExecutor) - spec = inspect.getfullargspec(get_exchange) - exchange_args = {k: v for k, v in kwargs.items() if k in spec.args} - trade_exchange = get_exchange(**exchange_args) + trade_exchange = get_exchange(**exchange_kwargs) setup_exchange(trade_env, trade_exchange) setup_exchange(trade_strategy, trade_exchange) diff --git a/qlib/contrib/backtest/account.py b/qlib/contrib/backtest/account.py index 39fecbd88..7e37c1093 100644 --- a/qlib/contrib/backtest/account.py +++ b/qlib/contrib/backtest/account.py @@ -3,13 +3,14 @@ import copy +import warnings import pandas as pd from .position import Position from .report import Report from .order import Order from ...data import D -from ...utils import parse_freq, sample_feature +from ...utils.sample import parse_freq, sample_feature """ @@ -110,6 +111,8 @@ class Account: for k, v in kwargs.items(): if hasattr(self, k): setattr(self, k, v) + else: + warnings.warn(f"reser error, attribute {k} is not found!") def get_positions(self): return self.positions diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index d67d6782b..d5f92ebae 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -1,10 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
- -import numpy as np -import pandas as pd - from .account import Account @@ -14,9 +10,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account) trade_strategy.reset(start_time=start_time, end_time=end_time) - trade_state = trade_env.get_init_state() + _execute_state = trade_env.get_init_state() while not trade_env.finished(): - _order_list = trade_strategy.generate_order_list(**trade_state) - trade_state, trade_info = trade_env.execute(_order_list) + _order_list = trade_strategy.generate_order_list(_execute_state) + _execute_state = trade_env.execute(_order_list) return trade_env.get_report() diff --git a/qlib/contrib/backtest/exchange.py b/qlib/contrib/backtest/exchange.py index 51f0dd68d..86045fd7a 100644 --- a/qlib/contrib/backtest/exchange.py +++ b/qlib/contrib/backtest/exchange.py @@ -11,7 +11,7 @@ import pandas as pd from ...data.data import D from ...data.dataset.utils import get_level_index from ...config import C, REG_CN -from ...utils import sample_feature +from ...utils.sample import sample_feature from ...log import get_module_logger from .order import Order diff --git a/qlib/contrib/backtest/env.py b/qlib/contrib/backtest/executor.py similarity index 63% rename from qlib/contrib/backtest/env.py rename to qlib/contrib/backtest/executor.py index eb922cefd..935af7361 100644 --- a/qlib/contrib/backtest/env.py +++ b/qlib/contrib/backtest/executor.py @@ -1,19 +1,34 @@ -import re -import json import copy import warnings -import pathlib -import numpy as np import pandas as pd +from typing import Tuple, List, Union, Optional, Callable from ...data.data import Cal -from ...utils import get_sample_freq_calendar, parse_freq -from .position import Position +from ...strategy.base import BaseStrategy +from ...utils import init_instance_by_config +from ...utils.sample import get_sample_freq_calendar, parse_freq from .report import Report from .order 
import Order +from .account import Account +from .exchange import Exchange class BaseTradeCalendar: - def __init__(self, step_bar, start_time=None, end_time=None, **kwargs): + def __init__( + self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None + ): + """ + Parameters + ---------- + step_bar : str + frequency of each trading step bar + start_time : Union[str, pd.Timestamp], optional + start time of trading, by default None + If `start_time` is None, it must be reset before trading. + end_time : Union[str, pd.Timestamp], optional + end time of trading, by default None + If `end_time` is None, it must be reset before trading. + """ + self.step_bar = step_bar self.reset(start_time=start_time, end_time=end_time) @@ -27,10 +42,9 @@ class BaseTradeCalendar: if self.start_time and self.end_time: _calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar) self.calendar = _calendar - _start_time, _end_time, _start_index, _end_index = Cal.locate_index( + _, _, _start_index, _end_index = Cal.locate_index( self.start_time, self.end_time, freq=freq, freq_sam=freq_sam ) - _trade_calendar = self.calendar[_start_index : _end_index + 1] self.start_index = _start_index self.end_index = _end_index self.trade_len = _end_index - _start_index + 1 @@ -45,6 +59,8 @@ class BaseTradeCalendar: for k, v in kwargs.items(): if hasattr(self, k): setattr(self, k, v) + else: + warnings.warn(f"reser error, attribute {k} is not found!") def _get_calendar_time(self, trade_index=1, shift=0): trade_index = trade_index - shift @@ -55,34 +71,43 @@ class BaseTradeCalendar: return self.trade_index >= self.trade_len - 1 def step(self): + if self.finished(): + raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!") self.trade_index = self.trade_index + 1 -class BaseEnv(BaseTradeCalendar): - """ - # Strategy framework document - - class Env(BaseEnv): - """ +class BaseExecutor(BaseTradeCalendar): + 
"""Base executor for trading""" def __init__( self, - step_bar, - start_time=None, - end_time=None, - trade_account=None, - generate_report=False, - verbose=False, + step_bar: str, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + trade_account: Account = None, + generate_report: bool = False, + verbose: bool = False, **kwargs, ): - self.generate_report = generate_report - self.verbose = verbose - super(BaseEnv, self).__init__( + """ + Parameters + ---------- + trade_account : Account, optional + trade account for trading, by default None + If `trade_account` is None, it must be reset before trading + generate_report : bool, optional + whether to generate report, by default False + verbose : bool, optional + whether to print log, by default False + """ + super(BaseExecutor, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs ) + self.generate_report = generate_report + self.verbose = verbose def reset(self, trade_account=None, **kwargs): - super(BaseEnv, self).reset(**kwargs) + super(BaseExecutor, self).reset(**kwargs) if trade_account: self.trade_account = trade_account self.trade_account.reset(freq=self.step_bar, report=Report(), positions={}) @@ -101,23 +126,31 @@ class BaseEnv(BaseTradeCalendar): raise NotImplementedError("get_report is not implemented!") -class SplitEnv(BaseEnv): +class SplitExecutor(BaseExecutor): def __init__( self, - step_bar, - sub_env, - sub_strategy, - start_time=None, - end_time=None, - trade_account=None, - trade_exchange=None, - generate_report=False, - verbose=False, + step_bar: str, + sub_env: Union[BaseExecutor, dict], + sub_strategy: Union[BaseStrategy, dict], + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + trade_account: Account = None, + trade_exchange: Exchange = None, + generate_report: bool = False, + verbose: bool = False, **kwargs, ): - self.sub_env = sub_env - 
self.sub_strategy = sub_strategy - super(SplitEnv, self).__init__( + """ + Parameters + ---------- + sub_env : BaseExecutor + trading env in each trading bar. + sub_strategy : BaseStrategy + trading strategy in each trading bar + trade_exchange : Exchange + exchange that provides market info + """ + super(SplitExecutor, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, @@ -127,28 +160,26 @@ class SplitEnv(BaseEnv): verbose=verbose, **kwargs, ) + self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor) + self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy) def reset(self, trade_account=None, trade_exchange=None, **kwargs): - super(SplitEnv, self).reset(trade_account=trade_account, **kwargs) + + super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs) if trade_account: self.sub_env.reset(trade_account=copy.copy(trade_account)) if trade_exchange: self.trade_exchange = trade_exchange - def execute(self, order_list, **kwargs): - if self.finished(): - raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") - # if self.track: - # yield action - # episode_reward = 0 - super(SplitEnv, self).step() + def execute(self, order_list): + super(SplitExecutor, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time) self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list) - trade_state = self.sub_env.get_init_state() + _execute_state = self.sub_env.get_init_state() while not self.sub_env.finished(): - _order_list = self.sub_strategy.generate_order_list(**trade_state) - trade_state, trade_info = self.sub_env.execute(order_list=_order_list) + _order_list = self.sub_strategy.generate_order_list(_execute_state) + _execute_state = self.sub_env.execute(order_list=_order_list) 
self.trade_account.update_bar_end( trade_start_time=trade_start_time, @@ -156,9 +187,8 @@ class SplitEnv(BaseEnv): trade_exchange=self.trade_exchange, update_report=self.generate_report, ) - _obs = {"current": self.trade_account.current} - _info = {} - return _obs, _info + _execute_state = {"current": self.trade_account.current} + return _execute_state def get_report(self): sub_env_report_dict = self.sub_env.get_report() @@ -167,12 +197,10 @@ class SplitEnv(BaseEnv): _positions = self.trade_account.get_positions() _count, _freq = parse_freq(self.step_bar) sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)}) - return sub_env_report_dict - else: - return sub_env_report_dict + return sub_env_report_dict -class SimulatorEnv(BaseEnv): +class SimulatorExecutor(BaseExecutor): def __init__( self, step_bar, @@ -184,7 +212,13 @@ class SimulatorEnv(BaseEnv): verbose=False, **kwargs, ): - super(SimulatorEnv, self).__init__( + """ + Parameters + ---------- + trade_exchange : Exchange + exchange that provides market info + """ + super(SimulatorExecutor, self).__init__( step_bar=step_bar, start_time=start_time, end_time=end_time, @@ -196,17 +230,12 @@ class SimulatorEnv(BaseEnv): ) def reset(self, trade_exchange=None, **kwargs): - super(SimulatorEnv, self).reset(**kwargs) + super(SimulatorExecutor, self).reset(**kwargs) if trade_exchange: self.trade_exchange = trade_exchange - def execute(self, order_list, **kwargs): - """ - Return: obs, done, info - """ - if self.finished(): - raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") - super(SimulatorEnv, self).step() + def execute(self, order_list): + super(SimulatorExecutor, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) trade_info = [] for order in order_list: @@ -219,21 +248,25 @@ class SimulatorEnv(BaseEnv): if self.verbose: if order.direction == Order.SELL: # sell print( - "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, 
value {:.2f}.".format( + "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, + order.amount, order.deal_amount, + order.factor, trade_val, ) ) else: print( - "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format( + "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format( trade_start_time, order.stock_id, trade_price, + order.amount, order.deal_amount, + order.factor, trade_val, ) ) @@ -249,9 +282,8 @@ class SimulatorEnv(BaseEnv): trade_exchange=self.trade_exchange, update_report=self.generate_report, ) - _obs = {"current": self.trade_account.current} - _info = {"trade_info": trade_info} - return _obs, _info + _execute_state = {"current": self.trade_account.current, "trade_info": trade_info} + return _execute_state def get_report(self): if self.generate_report: diff --git a/qlib/contrib/backtest/interpreter.py b/qlib/contrib/backtest/interpreter.py deleted file mode 100644 index 7f33c809d..000000000 --- a/qlib/contrib/backtest/interpreter.py +++ /dev/null @@ -1,16 +0,0 @@ -class BaseInterpreter: - @staticmethod - def interpret(**kwargs): - raise NotImplementedError("interpret is not implemented!") - - -class ActionInterpreter: - @staticmethod - def interpret(action, **kwargs): - return action - - -class StateInterpreter: - @staticmethod - def interpret(state, **kwargs): - return state diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 91cfc1d89..10f80671e 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -10,6 +10,7 @@ import warnings from ..log import get_module_logger from .backtest import get_exchange, backtest as backtest_func from ..utils import get_date_range +from ..utils.sample import parse_freq from ..data import D from ..config import C @@ -19,7 +20,7 @@ from ..data.dataset.utils import get_level_index logger = get_module_logger("Evaluate") -def 
risk_analysis(r, N=252): +def risk_analysis(r, N: int = None, freq: str = None): """Risk Analysis Parameters @@ -27,8 +28,26 @@ def risk_analysis(r, N=252): r : pandas.Series daily return series. N: int - scaler for annualizing information_ratio (day: 250, week: 50, month: 12). + scaler for annualizing information_ratio (day: 250, week: 50, month: 12), at least one of `N` and `freq` should exist + freq: str + analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist """ + + def cal_risk_analysis_scaler(freq): + _count, _freq = parse_freq(freq) + _freq_scaler = { + "minute": 240 * 250, + "day": 250, + "week": 50, + "month": 12, + } + return _count * _freq_scaler[_freq] + + if N is None and freq is None: + raise ValueError("at least one of `N` and `freq` should exist") + if N is None: + N = cal_risk_analysis_scaler(freq) + mean = r.mean() std = r.std(ddof=1) annualized_return = mean * N diff --git a/qlib/contrib/online/executor.py b/qlib/contrib/online/executor.py deleted file mode 100644 index 2bd0937a0..000000000 --- a/qlib/contrib/online/executor.py +++ /dev/null @@ -1,291 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. - - -import re -import json -import copy -import pathlib -import pandas as pd -from ...data import D -from ...utils import get_date_in_file_name -from ...utils import get_pre_trading_date -from ..backtest.order import Order - - -class BaseExecutor: - """ - # Strategy framework document - - class Executor(BaseExecutor): - """ - - def execute(self, trade_account, order_list, trade_date): - """ - return the executed result (trade_info) after trading at trade_date. - NOTICE: trade_account will not be modified after executing. 
- Parameter - --------- - trade_account : Account() - order_list : list - [Order()] - trade_date : pd.Timestamp - Return - --------- - trade_info : list - [Order(), float, float, float] - """ - raise NotImplementedError("get_execute_result for this model is not implemented.") - - def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date): - """ - Save the trade_info to the .csv transaction file in disk - the columns of result file is - ['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor'] - Parameter - --------- - trade_info : list of [Order(), float, float, float] - (order, trade_val, trade_cost, trade_price), trade_info with out factor - user_path: str / pathlib.Path() - the sub folder to save user data - - transaction_path : string / pathlib.Path() - """ - YYYY, MM, DD = str(trade_date.date()).split("-") - folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM - if not folder_path.exists(): - folder_path.mkdir(parents=True) - transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date())) - columns = [ - "date", - "stock_id", - "direction", - "amount", - "trade_val", - "trade_cost", - "trade_price", - "factor", - ] - data = [] - for [order, trade_val, trade_cost, trade_price] in trade_info: - data.append( - [ - trade_date, - order.stock_id, - order.direction, - order.amount, - trade_val, - trade_cost, - trade_price, - order.factor, - ] - ) - df = pd.DataFrame(data, columns=columns) - df.to_csv(transaction_path, index=False) - - def load_trade_info_from_executed_file(self, user_path, trade_date): - YYYY, MM, DD = str(trade_date.date()).split("-") - file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date())) - if not file_path.exists(): - raise ValueError("File {} not exists!".format(file_path)) - - filedate = get_date_in_file_name(file_path) - transaction = pd.read_csv(file_path) - trade_info = [] - for i in range(len(transaction)): - 
date = transaction.loc[i]["date"] - if not date == filedate: - continue - # raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate)) - order = Order( - stock_id=transaction.loc[i]["stock_id"], - amount=transaction.loc[i]["amount"], - trade_date=transaction.loc[i]["date"], - direction=transaction.loc[i]["direction"], - factor=transaction.loc[i]["factor"], - ) - trade_val = transaction.loc[i]["trade_val"] - trade_cost = transaction.loc[i]["trade_cost"] - trade_price = transaction.loc[i]["trade_price"] - trade_info.append([order, trade_val, trade_cost, trade_price]) - return trade_info - - -class SimulatorExecutor(BaseExecutor): - def __init__(self, trade_exchange, verbose=False): - self.trade_exchange = trade_exchange - self.verbose = verbose - self.order_list = [] - - def execute(self, trade_account, order_list, trade_date): - """ - execute the order list, do the trading wil exchange at date. - Will not modify the trade_account. - Parameter - trade_account : Account() - order_list : list - list or orders - trade_date : pd.Timestamp - :return: - trade_info : list of [Order(), float, float, float] - (order, trade_val, trade_cost, trade_price), trade_info with out factor - """ - account = copy.deepcopy(trade_account) - trade_info = [] - - for order in order_list: - # check holding thresh is done in strategy - # if order.direction==0: # sell order - # # checking holding thresh limit for sell order - # if trade_account.current.get_stock_count(order.stock_id) < thresh: - # # can not sell this code - # continue - # is order executable - # check order - if self.trade_exchange.check_order(order) is True: - # execute the order - trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account) - trade_info.append([order, trade_val, trade_cost, trade_price]) - if self.verbose: - if order.direction == Order.SELL: # sell - print( - "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value 
{:.2f}.".format( - trade_date, - order.stock_id, - trade_price, - order.deal_amount, - trade_val, - ) - ) - else: - print( - "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format( - trade_date, - order.stock_id, - trade_price, - order.deal_amount, - trade_val, - ) - ) - - else: - if self.verbose: - print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id)) - # do nothing - pass - return trade_info - - -def save_score_series(score_series, user_path, trade_date): - """Save the score_series into a .csv file. - The columns of saved file is - [stock_id, score] - - Parameter - --------- - order_list: [Order()] - list of Order() - date: pd.Timestamp - the date to save the order list - user_path: str / pathlib.Path() - the sub folder to save user data - """ - user_path = pathlib.Path(user_path) - YYYY, MM, DD = str(trade_date.date()).split("-") - folder_path = user_path / "score" / YYYY / MM - if not folder_path.exists(): - folder_path.mkdir(parents=True) - file_path = folder_path / "score_{}.csv".format(str(trade_date.date())) - score_series.to_csv(file_path) - - -def load_score_series(user_path, trade_date): - """Save the score_series into a .csv file. - The columns of saved file is - [stock_id, score] - - Parameter - --------- - order_list: [Order()] - list of Order() - date: pd.Timestamp - the date to save the order list - user_path: str / pathlib.Path() - the sub folder to save user data - """ - user_path = pathlib.Path(user_path) - YYYY, MM, DD = str(trade_date.date()).split("-") - folder_path = user_path / "score" / YYYY / MM - if not folder_path.exists(): - folder_path.mkdir(parents=True) - file_path = folder_path / "score_{}.csv".format(str(trade_date.date())) - score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"]) - return score_series - - -def save_order_list(order_list, user_path, trade_date): - """ - Save the order list into a json file. 
- Will calculate the real amount in order according to factors at date. - - The format in json file like - {"sell": {"stock_id": amount, ...} - ,"buy": {"stock_id": amount, ...}} - - :param - order_list: [Order()] - list of Order() - date: pd.Timestamp - the date to save the order list - user_path: str / pathlib.Path() - the sub folder to save user data - """ - user_path = pathlib.Path(user_path) - YYYY, MM, DD = str(trade_date.date()).split("-") - folder_path = user_path / "trade" / YYYY / MM - if not folder_path.exists(): - folder_path.mkdir(parents=True) - sell = {} - buy = {} - for order in order_list: - if order.direction == 0: # sell - sell[order.stock_id] = [order.amount, order.factor] - else: - buy[order.stock_id] = [order.amount, order.factor] - order_dict = {"sell": sell, "buy": buy} - file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date())) - with file_path.open("w") as fp: - json.dump(order_dict, fp) - - -def load_order_list(user_path, trade_date): - user_path = pathlib.Path(user_path) - YYYY, MM, DD = str(trade_date.date()).split("-") - path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date())) - if not path.exists(): - raise ValueError("File {} not exists!".format(path)) - # get orders - with path.open("r") as fp: - order_dict = json.load(fp) - order_list = [] - for stock_id in order_dict["sell"]: - amount, factor = order_dict["sell"][stock_id] - order = Order( - stock_id=stock_id, - amount=amount, - trade_date=pd.Timestamp(trade_date), - direction=Order.SELL, - factor=factor, - ) - order_list.append(order) - for stock_id in order_dict["buy"]: - amount, factor = order_dict["buy"][stock_id] - order = Order( - stock_id=stock_id, - amount=amount, - trade_date=pd.Timestamp(trade_date), - direction=Order.BUY, - factor=factor, - ) - order_list.append(order) - return order_list diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/model_strategy.py index 6899a10a5..1fc1bf070 100644 
--- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/model_strategy.py @@ -3,7 +3,7 @@ import warnings import numpy as np import pandas as pd -from ...utils import sample_feature +from ...utils.sample import sample_feature from ...strategy.base import ModelStrategy from ..backtest.order import Order from .order_generator import OrderGenWInteract @@ -66,7 +66,7 @@ class TopkDropoutStrategy(ModelStrategy): if trade_exchange: self.trade_exchange = trade_exchange - def get_risk_degree(self, trade_index): + def get_risk_degree(self, trade_index=None): """get_risk_degree Return the proportion of your total value you will used in investment. Dynamically risk_degree will result in Market timing. @@ -74,7 +74,7 @@ class TopkDropoutStrategy(ModelStrategy): # It will use 95% amoutn of your total value by default return self.risk_degree - def generate_order_list(self, current, **kwargs): + def generate_order_list(self, execute_state): super(TopkDropoutStrategy, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1) @@ -120,6 +120,7 @@ class TopkDropoutStrategy(ModelStrategy): def filter_stock(l): return l + current = execute_state.get("current") current_temp = copy.deepcopy(current) # generate order list for this adjust date sell_order_list = [] @@ -163,6 +164,7 @@ class TopkDropoutStrategy(ModelStrategy): # Get the stock list we really want to buy buy = today[: len(sell) + self.topk - len(last)] + print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last)) # print("flag", len(sell), len(buy), self.topk, len(last)) for code in current_stock_list: if not self.trade_exchange.is_stock_tradable( @@ -175,13 +177,17 @@ class TopkDropoutStrategy(ModelStrategy): continue # sell order sell_amount = current_temp.get_stock_amount(code=code) + factor = self.trade_exchange.get_factor( + stock_id=code, start_time=trade_start_time, 
end_time=trade_end_time + ) + # sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor) sell_order = Order( stock_id=code, amount=sell_amount, start_time=trade_start_time, end_time=trade_end_time, direction=Order.SELL, # 0 for sell, 1 for buy - factor=self.trade_exchange.get_factor(code, trade_start_time, trade_end_time), + factor=factor, ) # is order executable if self.trade_exchange.check_order(sell_order): @@ -228,19 +234,36 @@ class WeightStrategyBase(ModelStrategy): def __init__( self, step_bar, + model, + dataset, start_time=None, end_time=None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, **kwargs, ): - super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time) - self.trade_exchange = trade_exchange + super(WeightStrategyBase, self).__init__( + step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs + ) + if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + def reset(self, trade_exchange=None, **kwargs): + super(WeightStrategyBase, self).reset(**kwargs) + if trade_exchange: + self.trade_exchange = trade_exchange + + def get_risk_degree(self, trade_index=None): + """get_risk_degree + Return the proportion of your total value you will used in investment. + Dynamically risk_degree will result in Market timing. 
+ """ + # It will use 95% amoutn of your total value by default + return 0.95 + def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time): """ Generate target position from score for this date and the current position.The cash is not considered in the position @@ -256,7 +279,7 @@ class WeightStrategyBase(ModelStrategy): """ raise NotImplementedError() - def generate_order_list(self, current, **kwargs): + def generate_order_list(self, execute_state): """ Parameters ----------- @@ -277,7 +300,8 @@ class WeightStrategyBase(ModelStrategy): pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") if pred_score is None: return [] - current_temp = copy.deepcopy(trade_account.current) + current = execute_state.get("current") + current_temp = copy.deepcopy(current) target_weight_position = self.generate_target_weight_position( score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time ) diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 45df94830..073f513c7 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -3,14 +3,15 @@ import warnings import numpy as np import pandas as pd -from ...utils import sample_feature + +from ...utils.sample import sample_feature from ...data.data import D -from ...data.dataset.utils import get_level_index -from ...strategy.base import RuleStrategy, TradingEnhancement +from ...data.dataset.utils import convert_index_format +from ...strategy.base import RuleStrategy, OrderEnhancement from ..backtest.order import Order -class TWAPStrategy(RuleStrategy, TradingEnhancement): +class TWAPStrategy(RuleStrategy, OrderEnhancement): def __init__( self, step_bar, @@ -23,7 +24,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(TWAPStrategy, 
self).reset(**kwargs) - TradingEnhancement.reset(self, trade_order_list=trade_order_list) + OrderEnhancement.reset(self, trade_order_list=trade_order_list) if trade_exchange: self.trade_exchange = trade_exchange if trade_order_list: @@ -31,7 +32,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): for order in self.trade_order_list: self.trade_amount[(order.stock_id, order.direction)] = order.amount - def generate_order_list(self, **kwargs): + def generate_order_list(self, execute_state): super(TWAPStrategy, self).step() trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index) order_list = [] @@ -66,7 +67,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement): return order_list -class SBBStrategyBase(RuleStrategy, TradingEnhancement): +class SBBStrategyBase(RuleStrategy, OrderEnhancement): """ (S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy. """ @@ -87,7 +88,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): def reset(self, trade_order_list=None, trade_exchange=None, **kwargs): super(SBBStrategyBase, self).reset(**kwargs) - TradingEnhancement.reset(self, trade_order_list=trade_order_list) + OrderEnhancement.reset(self, trade_order_list=trade_order_list) if trade_exchange: self.trade_exchange = trade_exchange if trade_order_list is not None: @@ -100,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None): raise NotImplementedError("pred_price_trend method is not implemented!") - def generate_order_list(self, **kwargs): + def generate_order_list(self, execute_state): super(SBBStrategyBase, self).step() if not self.trade_order_list: return [] @@ -109,7 +110,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): order_list = [] for order in self.trade_order_list: if self.trade_index % 2 == 1: - _pred_trend = self._pred_price_trend(order.stock_id) + _pred_trend = 
self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time) else: _pred_trend = self.trade_trend[(order.stock_id, order.direction)] @@ -127,7 +128,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): _order_amount = self.trade_amount[(order.stock_id, order.direction)] / ( self.trade_len - self.trade_index ) - if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( (trade_unit_cnt + self.trade_len - self.trade_index - 1) @@ -146,6 +147,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): factor=order.factor, ) order_list.append(_order) + # print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit) else: _order_amount = None if _amount_trade_unit is None: @@ -154,12 +156,12 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): * self.trade_amount[(order.stock_id, order.direction)] / (self.trade_len - self.trade_index + 1) ) - if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: + elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit: trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit) _order_amount = ( - 2 - * (trade_unit_cnt + self.trade_len - self.trade_index) + (trade_unit_cnt + self.trade_len - self.trade_index) // (self.trade_len - self.trade_index + 1) + * 2 * _amount_trade_unit ) if _order_amount: @@ -197,6 +199,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement): factor=order.factor, ) order_list.append(_order) + # print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit) if self.trade_index % 2 == 1: self.trade_trend[(order.stock_id, order.direction)] = _pred_trend @@ -226,20 +229,15 @@ class 
SBBStrategyEMA(SBBStrategyBase): self.instruments = D.instruments(instruments) self.freq = freq - def _convert_index_format(self, df): - if get_level_index(df, level="datetime") == 1: - df = df.swaplevel().sort_index() - return df - - def _reset_trade_calendar(self, start_time=None, end_time=None): - super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time) + def reset(self, start_time=None, end_time=None, **kwargs): + super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs) if self.start_time and self.end_time: fields = ["EMA($close, 10)-EMA($close, 20)"] signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1) signal_df = D.features( self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq ) - signal_df = self._convert_index_format(signal_df) + signal_df = convert_index_format(signal_df) signal_df.columns = ["signal"] self.signal = {} for stock_id, stock_val in signal_df.groupby(level="instrument"): diff --git a/qlib/data/data.py b/qlib/data/data.py index d44139c80..91a21da9f 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -25,7 +25,8 @@ from ..log import get_module_logger from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname from .base import Feature from .cache import DiskDatasetCache, DiskExpressionCache -from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path, sample_calendar +from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path +from ..utils.sample import sample_calendar class CalendarProvider(abc.ABC): @@ -35,7 +36,7 @@ class CalendarProvider(abc.ABC): """ @abc.abstractmethod - def calendar(self, start_time=None, end_time=None, freq="day", future=False): + def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False): """Get calendar of certain market in given time 
range. Parameters @@ -46,6 +47,8 @@ class CalendarProvider(abc.ABC): end of the time range. freq : str time frequency, available: year/quarter/month/week/day. + freq_sam : str + sample frequency used for sampling lower-frequency calendar, by default None(raw calendar). future : bool whether including future trading day. @@ -769,7 +772,7 @@ class ClientCalendarProvider(CalendarProvider): def set_conn(self, conn): self.conn = conn - def calendar(self, start_time=None, end_time=None, freq="day", future=False): + def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False): self.conn.send_request( request_type="calendar", @@ -937,8 +940,8 @@ class BaseProvider: To keep compatible with old qlib provider. """ - def calendar(self, start_time=None, end_time=None, freq="day", future=False): - return Cal.calendar(start_time, end_time, freq, future=future) + def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False): + return Cal.calendar(start_time, end_time, freq, freq_sam, future=future) def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None): if start_time is not None or end_time is not None: diff --git a/qlib/data/dataset/utils.py b/qlib/data/dataset/utils.py index feda19044..f7b07d563 100644 --- a/qlib/data/dataset/utils.py +++ b/qlib/data/dataset/utils.py @@ -70,3 +70,27 @@ def fetch_df_by_index( return df.loc[ pd.IndexSlice[idx_slc], ] + + +def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]: + """ + Convert the format of df.MultiIndex according to the following rules: + - If `level` is the first level of df.MultiIndex, do nothing + - If `level` is the second level of df.MultiIndex, swap the level of index. 
+ + Parameters + ---------- + df : Union[pd.DataFrame, pd.Series] + raw DataFrame/Series + level : str, optional + the level that will be converted to the first one, by default "datetime" + + Returns + ------- + Union[pd.DataFrame, pd.Series] + converted DataFrame/Series + """ + + if get_level_index(df, level=level) == 1: + df = df.swaplevel().sort_index() + return df diff --git a/qlib/rl/__init__.py b/qlib/rl/__init__.py new file mode 100644 index 000000000..59e481eb9 --- /dev/null +++ b/qlib/rl/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. diff --git a/qlib/rl/env.py b/qlib/rl/env.py new file mode 100644 index 000000000..9424aafab --- /dev/null +++ b/qlib/rl/env.py @@ -0,0 +1,104 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from .interpreter import StateInterpreter, ActionInterpreter + +from ..contrib.backtest.executor import BaseExecutor + + +class BaseRLEnv: + def reset(self, **kwargs): + raise NotImplementedError("reset is not implemented!") + + def step(self, action): + """ + step method of rl env + Parameters + ---------- + action : + action from rl policy + + Returns + ------- + env state to rl policy + """ + raise NotImplementedError("step is not implemented!") + + +class QlibRLEnv: + """qlib-based RL env""" + + def __init__( + self, + executor: BaseExecutor, + ): + """ + Parameters + ---------- + executor : BaseExecutor + qlib multi-level/single-level executor, which can be regarded as gamecore in RL + """ + self.executor = executor + + def reset(self, **kwargs): + self.executor.reset(**kwargs) + + +class QlibIntRLEnv(QlibRLEnv): + """(Qlib)-based RL (Env) with (Interpreter)""" + + def __init__( + self, + executor: BaseExecutor, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + state_interpret_kwargs: dict = {}, + action_interpret_kwargs: dict = {}, + ): + """ + + Parameters + ---------- + state_interpreter : StateInterpreter + 
interpretor that interprets the qlib execute result into rl env state. + action_interpreter : ActionInterpreter + interpretor that interprets the rl agent action into qlib order list + state_interpret_kwargs : dict, optional + arguments may be used in `state_interpreter.interpret`, by default {} + such as the following arguments: + - trade exchange : Exchange + Exchange that can provide market info + action_interpret_kwargs: dict, optional + arguments may be used in `action_interpreter.interpret`, by default {} + such as the following arguments: + - trade_order_list : List[Order] + If the strategy is used to split order, it presents the trade order pool. + """ + super(QlibIntRLEnv, self).__init__(executor=executor) + self.state_interpreter = state_interpreter + self.action_interpreter = action_interpreter + self.state_interpret_kwargs = state_interpret_kwargs + self.action_interpret_kwargs = action_interpret_kwargs + + def step(self, action): + """ + step method of rl env, it run as following step: + - Use `action_interpreter.interpret` method to interpret the agent action into order list + - Execute the order list with qlib executor, and get the executed result + - Use `state_interpreter.interpret` method to interpret the executed result into env state + + Parameters + ---------- + action : + action from rl policy + + Returns + ------- + env state to rl rl policy + """ + _interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs) + _execute_result = self.executor.execute(_interpret_action) + _interpret_state = self.state_interpreter.interpret( + execute_result=_execute_result, **self.action_interpret_kwargs + ) + return _interpret_state diff --git a/qlib/rl/interpreter.py b/qlib/rl/interpreter.py new file mode 100644 index 000000000..bad337f72 --- /dev/null +++ b/qlib/rl/interpreter.py @@ -0,0 +1,20 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ + +class BaseInterpreter: + @staticmethod + def interpret(**kwargs): + raise NotImplementedError("interpret is not implemented!") + + +class ActionInterpreter(BaseInterpreter): + @staticmethod + def interpret(action, **kwargs): + raise NotImplementedError("interpret is not implemented!") + + +class StateInterpreter(BaseInterpreter): + @staticmethod + def interpret(execute_result, **kwargs): + raise NotImplementedError("interpret is not implemented!") diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 8a857eb00..a5e7210bd 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -1,55 +1,160 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. - -import copy -import warnings -import numpy as np import pandas as pd +from typing import Tuple, List, Union, Optional, Callable -from ..utils import get_sample_freq_calendar +from ..model.base import BaseModel from ..data.dataset import DatasetH -from ..data.dataset.utils import get_level_index +from ..data.dataset.utils import convert_index_format from ..contrib.backtest.order import Order -from ..contrib.backtest.env import BaseTradeCalendar - -""" -1. 
BaseStrategy 的粒度一定是数据粒度的整数倍 -- 关于calendar的合并咋整 -- adjust_dates这个东西啥用 -- label和freq和strategy的bar分离,这个如何决策呢 -""" +from ..contrib.backtest.executor import BaseTradeCalendar +from ..rl.interpreter import ActionInterpreter, StateInterpreter class BaseStrategy(BaseTradeCalendar): - def generate_order_list(self, **kwargs): + """Base strategy""" + + def generate_order_list(self, execute_state): + """Generate order list in each trading bar""" raise NotImplementedError("generator_order_list is not implemented!") class RuleStrategy(BaseStrategy): + """Trading strategy with rules""" + pass class ModelStrategy(BaseStrategy): - def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs): + """Trading Strategy by using Model to make predictions""" + + def __init__( + self, + step_bar: str, + model: BaseModel, + dataset: DatasetH, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + **kwargs, + ): + """ + Parameters + ---------- + model : BaseModel + the model used in when making predictions + dataset : DatasetH + provide test data for model + kwargs : dict + arguments that will be passed into `reset` method + """ self.model = model self.dataset = dataset - self.pred_scores = self._convert_index_format(self.model.predict(dataset)) + self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") # pred_score_dates = self.pred_scores.index.get_level_values(level="datetime") super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) - def _convert_index_format(self, df): - if get_level_index(df, level="datetime") == 1: - df = df.swaplevel().sort_index() - return df - def _update_model(self): - """update pred score""" + """ + Update model in each bar when using online data as the following steps: + - update dataset with online data, the dataset should support online update + - make the latest prediction scores of the new bar + - update the pred score into the latest 
prediction + """ raise NotImplementedError("_update_model is not implemented!") -class TradingEnhancement: - def reset(self, trade_order_list=None): +class RLStrategy(BaseStrategy): + """RL-based Strategy""" + + def __init__( + self, + step_bar: str, + policy, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + **kwargs, + ): + """ + Parameters + ---------- + policy : + RL policy for generate action + """ + super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs) + self.policy = policy + + +class RLIntStrategy(RLStrategy): + """(RL)-based (Strategy) with (Int)erpreter""" + + def __init__( + self, + step_bar: str, + policy, + state_interpreter: StateInterpreter, + action_interpreter: ActionInterpreter, + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + state_interpret_kwargs: dict = {}, + action_interpret_kwargs: dict = {}, + **kwargs, + ): + """ + Parameters + ---------- + state_interpreter : StateInterpreter + interpretor that interprets the qlib execute result into rl env state. + action_interpreter : ActionInterpreter + interpretor that interprets the rl agent action into qlib order list + start_time : Union[str, pd.Timestamp], optional + start time of trading, by default None + end_time : Union[str, pd.Timestamp], optional + end time of trading, by default None + state_interpret_kwargs : dict, optional + arguments may be used in `state_interpreter.interpret`, by default {} + such as the following arguments: + - trade exchange : Exchange + Exchange that can provide market info + action_interpret_kwargs: dict, optional + arguments may be used in `action_interpreter.interpret`, by default {} + such as the following arguments: + - trade_order_list : List[Order] + If the strategy is used to split order, it presents the trade order pool. 
+ """ + super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs) + + self.policy = policy + self.action_interpreter = action_interpreter + self.state_interpreter = state_interpreter + self.state_interpret_kwargs = state_interpret_kwargs + self.action_interpret_kwargs = action_interpret_kwargs + + def generate_order_list(self, execute_state): + super(RLStrategy, self).step() + _interpret_state = self.state_interpretor.interpret( + execute_result=execute_state, **self.action_interpret_kwargs + ) + _policy_action = self.policy.step(_interpret_state) + _order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs) + return _order_list + + +class OrderEnhancement: + """ + Order enhancement for strategy + - If the strategy is used to split orders, the enhancement should be inherited + - If the strategy is used for portfolio management, the enhancement can be ignored + """ + + def reset(self, trade_order_list: List[Order] = None): + """reset trade orders for split strategy + + Parameters + ---------- + trade_order_list for split strategy: List[Order], optional + trading orders , by default None + """ if trade_order_list is not None: self.trade_order_list = trade_order_list diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index a6bba1f38..15652dbaf 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -800,217 +800,3 @@ def fname_to_code(fname: str): if fname.startswith(prefix): fname = fname.lstrip(prefix) return fname - - -########################## Sample ############################ -def sample_calendar_bac(calendar_raw, freq_raw, freq_sam): - """ - freq_raw : "min" or "day" - """ - freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw - freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam - - if freq_sam.endswith(("minute", "min")): - - def cal_next_sam_minute(x, sam_minutes): - hour = x.hour - minute = x.minute - if 9 <= 
hour <= 11: - minute_index = (11 - hour) * 60 + 30 - minute + 120 - elif 13 <= hour <= 15: - minute_index = (15 - hour) * 60 - minute - else: - raise ValueError("calendar hour must be in [9, 11] or [13, 15]") - - minute_index = minute_index // sam_minutes * sam_minutes - - if 0 <= minute_index < 120: - return 15 - (minute_index + 59) // 60, (120 - minute_index) % 60 - elif 120 <= minute_index < 240: - return 11 - (minute_index - 120 + 29) // 60, (240 - minute_index + 30) % 60 - else: - raise ValueError("calendar minute_index error") - - sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6]) - - if not freq_raw.endswith(("minute", "min")): - raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") - else: - raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6]) - if raw_minutes > sam_minutes: - raise ValueError("raw freq must be higher than sample freq") - - _calendar_minute = np.unique( - list( - map( - lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59), - calendar_raw, - ) - ) - ) - return _calendar_minute - else: - - _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw))) - if freq_sam.endswith(("day", "d")): - sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3]) - return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days] - - elif freq_sam.endswith(("week", "w")): - sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4]) - _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) - _calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0] - return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks] - - elif freq_sam.endswith(("month", "m")): - sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5]) - 
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) - _calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0] - return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months] - else: - raise ValueError("sample freq must be xmin, xd, xw, xm") - - -def parse_freq(freq): - freq = freq.lower() - search_obj = re.search("^([0-9]*)([a-z]+)", freq) - if search_obj is None: - raise ValueError("freq format is not supported") - _count = int(search_obj.group(1) if search_obj.group(1) else "1") - _freq = search_obj.group(2) - _freq_format_dict = { - "month": "month", - "mon": "month", - "week": "week", - "w": "week", - "day": "day", - "d": "day", - "minute": "minute", - "min": "minute", - } - try: - _freq = _freq_format_dict.get(_freq) - except KeyError: - raise ValueError( - "freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min" - ) - return _count, _freq - - -def sample_calendar(calendar_raw, freq_raw, freq_sam): - """ - freq_raw : "min" or "day" - """ - raw_count, freq_raw = parse_freq(freq_raw) - sam_count, freq_sam = parse_freq(freq_sam) - if not len(calendar_raw): - return calendar_raw - if freq_sam == "minute": - - def cal_next_sam_minute(x, sam_minutes): - hour = x.hour - minute = x.minute - if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30): - minute_index = (hour - 9) * 60 + minute - 30 - elif 13 <= hour < 15: - minute_index = (hour - 13) * 60 + minute + 120 - else: - raise ValueError("calendar hour must be in [9, 11] or [13, 15]") - - minute_index = minute_index // sam_minutes * sam_minutes - - if 0 <= minute_index < 120: - return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60 - elif 120 <= minute_index < 240: - return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60 - else: - raise ValueError("calendar minute_index error") - - if req_raw != "minute": - raise ValueError("when sampling minute calendar, 
freq of raw calendar must be minute or min") - else: - if raw_count > sam_count: - raise ValueError("raw freq must be higher than sample freq") - _calendar_minute = np.unique( - list( - map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw) - ) - ) - if calendar_raw[0] > _calendar_minute[0]: - _calendar_minute[0] = calendar_raw[0] - return _calendar_minute - else: - _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) - if freq_sam == "day": - return _calendar_day[::sam_count] - - elif freq_sam == "week": - _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) - _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0] - return _calendar_week[::sam_count] - - elif freq_sam == "month": - _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) - _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0] - return _calendar_month[::sam_count] - else: - raise ValueError("sample freq must be xmin, xd, xw, xm") - - -def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs): - _, norm_freq = parse_freq(freq) - - from ..data.data import Cal - - try: - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs) - freq, freq_sam = freq, None - except ValueError: - freq_sam = freq - if norm_freq in ["month", "week", "day"]: - try: - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs) - freq = "day" - except ValueError: - raise - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) - freq = "min" - elif norm_freq == "minute": - _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs) - freq = "min" - else: - raise ValueError(f"freq {freq} is not supported") - return _calendar, freq, freq_sam - - -def 
sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}): - selector_datetime = slice(start_time, end_time) - fields = fields if fields else slice(None) - - from ..data.dataset.utils import get_level_index - - datetime_level = get_level_index(feature, level="datetime") == 0 - if isinstance(feature, pd.Series): - feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)] - elif isinstance(feature, pd.DataFrame): - feature = ( - feature.loc[selector_datetime, fields] - if datetime_level - else feature.loc[(slice(None), selector_datetime), fields] - ) - if feature.empty: - return None - if isinstance(feature.index, pd.MultiIndex): - if callable(method): - method_func = method - return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs)) - elif isinstance(method, str): - return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) - else: - if callable(method): - method_func = method - return method_func(feature, **method_kwargs) - elif isinstance(method, str): - return getattr(feature, method)(**method_kwargs) - - return feature diff --git a/qlib/utils/sample.py b/qlib/utils/sample.py new file mode 100644 index 000000000..9f67d4981 --- /dev/null +++ b/qlib/utils/sample.py @@ -0,0 +1,300 @@ +import re +import numpy as np +import pandas as pd +from typing import Tuple, List, Union, Optional, Callable + + +def parse_freq(freq: str) -> Tuple[int, str]: + """ + Parse freq into a unified format + + Parameters + ---------- + freq : str + Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$' + + Returns + ------- + freq: Tuple[int, str] + Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'. + Example: + + .. 
code-block:: + + print(parse_freq("day")) + (1, "day" ) + print(parse_freq("2mon")) + (2, "month") + print(parse_freq("10w")) + (10, "week") + + """ + freq = freq.lower() + match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq) + if match_obj is None: + raise ValueError( + "freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min" + ) + _count = int(match_obj.group(1) if match_obj.group(1) else "1") + _freq = match_obj.group(2) + _freq_format_dict = { + "month": "month", + "mon": "month", + "week": "week", + "w": "week", + "day": "day", + "d": "day", + "minute": "minute", + "min": "minute", + } + return _count, _freq_format_dict[_freq] + + +def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray: + """ + Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam + + Parameters + ---------- + calendar_raw : np.ndarray + The calendar with frequency freq_raw + freq_raw : str + Frequency of the raw calendar + freq_sam : str + Sample frequency + + Returns + ------- + np.ndarray + The calendar with frequency freq_sam + """ + raw_count, freq_raw = parse_freq(freq_raw) + sam_count, freq_sam = parse_freq(freq_sam) + if not len(calendar_raw): + return calendar_raw + if freq_sam == "minute": + + def cal_next_sam_minute(x, sam_minutes): + hour = x.hour + minute = x.minute + if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30): + minute_index = (hour - 9) * 60 + minute - 30 + elif 13 <= hour < 15: + minute_index = (hour - 13) * 60 + minute + 120 + else: + raise ValueError("calendar hour must be in [9, 11] or [13, 15]") + + minute_index = minute_index // sam_minutes * sam_minutes + + if 0 <= minute_index < 120: + return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60 + elif 120 <= minute_index < 240: + return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60 + else: + raise ValueError("calendar minute_index 
error") + + if freq_raw != "minute": + raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min") + else: + if raw_count > sam_count: + raise ValueError("raw freq must be higher than sampling freq") + _calendar_minute = np.unique( + list( + map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw) + ) + ) + if calendar_raw[0] > _calendar_minute[0]: + _calendar_minute[0] = calendar_raw[0] + return _calendar_minute + else: + _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) + if freq_sam == "day": + return _calendar_day[::sam_count] + + elif freq_sam == "week": + _day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day))) + _calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0] + return _calendar_week[::sam_count] + + elif freq_sam == "month": + _day_in_month = np.array(list(map(lambda x: x.day, _calendar_day))) + _calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0] + return _calendar_month[::sam_count] + else: + raise ValueError("sampling freq must be xmin, xd, xw, xm") + + +def get_sample_freq_calendar( + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + freq: str = "day", + future: bool = False, +) -> Tuple[np.ndarray, str, Optional[str]]: + """ + Get the calendar with frequency freq. + + - If the calendar with the raw frequency freq exists, return it directly + + - Else, sample from a higher frequency calendar automatically + + Parameters + ---------- + start_time : Union[str, pd.Timestamp], optional + start time of calendar, by default None + end_time : Union[str, pd.Timestamp], optional + end time of calendar, by default None + freq : str, optional + freq of calendar, by default "day" + future : bool, optional + whether including future trading day. 
+ + Returns + ------- + Tuple[np.ndarray, str, Optional[str]] + + - the first value is the calendar + - the second value is the raw freq of calendar + - the third value is the sampling freq of calendar, it's None if the raw frequency freq exists. + + """ + + _, norm_freq = parse_freq(freq) + + from ..data.data import Cal + + try: + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, future=future) + freq, freq_sam = freq, None + except ValueError: + freq_sam = freq + if norm_freq in ["month", "week", "day"]: + try: + _calendar = Cal.calendar( + start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future + ) + freq = "day" + except ValueError: + _calendar = Cal.calendar( + start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future + ) + freq = "min" + elif norm_freq == "minute": + _calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future) + freq = "min" + else: + raise ValueError(f"freq {freq} is not supported") + return _calendar, freq, freq_sam + + +def sample_feature( + feature: Union[pd.DataFrame, pd.Series], + start_time: Union[str, pd.Timestamp] = None, + end_time: Union[str, pd.Timestamp] = None, + fields: Union[str, List[str]] = None, + method: Union[str, Callable] = "last", + method_kwargs: dict = {}, +): + """ + Sample value from pandas DataFrame or Series for each stock + + - If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time] + Example: + + .. 
code-block:: + + print(feature) + $close $volume + instrument datetime + SH600000 2010-01-04 86.778313 16162960.0 + 2010-01-05 87.433578 28117442.0 + 2010-01-06 85.713585 23632884.0 + 2010-01-07 83.788803 20813402.0 + 2010-01-08 84.730675 16044853.0 + + SH600655 2010-01-04 2699.567383 158193.328125 + 2010-01-08 2612.359619 77501.406250 + 2010-01-11 2712.982422 160852.390625 + 2010-01-12 2788.688232 164587.937500 + 2010-01-13 2790.604004 145460.453125 + + print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + $close $volume + instrument + SH600000 87.433578 28117442.0 + SH600655 2699.567383 158193.328125 + + - Else, the `feature` should have Index[datetime], just apply the `method` to `feature` directly + Example: + + .. code-block:: + print(feature) + $close $volume + datetime + 2010-01-04 86.778313 16162960.0 + 2010-01-05 87.433578 28117442.0 + 2010-01-06 85.713585 23632884.0 + 2010-01-07 83.788803 20813402.0 + 2010-01-08 84.730675 16044853.0 + + print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last")) + + $close 87.433578 + $volume 28117442.0 + + print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last")) + + 87.433578 + + Parameters + ---------- + feature : Union[pd.DataFrame, pd.Series] + Raw feature to be sampled + start_time : Union[str, pd.Timestamp], optional + start sampling time, by default None + end_time : Union[str, pd.Timestamp], optional + end sampling time, by default None + fields : Union[str, List[str]], optional + column names, it's ignored when sample pd.Series data, by default None(all columns) + method : Union[str, Callable], optional + sample method, apply method function to each stock series data, by default "last" + - If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby + - If `feature` has 
MultiIndex[instrument, datetime], method must be a member of pandas.groupby when its type is str, or a callable function. + method_kwargs : dict, optional + arguments of method, by default {} + + Returns + ------- + The Sampled DataFrame/Series/Value + """ + + selector_datetime = slice(start_time, end_time) + if fields is None: + fields = slice(None) + + from ..data.dataset.utils import get_level_index + + datetime_level = get_level_index(feature, level="datetime") == 0 + if isinstance(feature, pd.Series): + feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)] + elif isinstance(feature, pd.DataFrame): + feature = ( + feature.loc[selector_datetime, fields] + if datetime_level + else feature.loc[(slice(None), selector_datetime), fields] + ) + if feature.empty: + return None + if isinstance(feature.index, pd.MultiIndex): + if callable(method): + method_func = method + return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs)) + elif isinstance(method, str): + return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) + else: + if callable(method): + method_func = method + return method_func(feature, **method_kwargs) + elif isinstance(method, str): + return getattr(feature, method)(**method_kwargs) + + return feature diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index fa1dc2e25..8a8bde7ef 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -14,7 +14,8 @@ from ..data.dataset import DatasetH from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger -from ..utils import flatten_dict, parse_freq +from ..utils import flatten_dict +from ..utils.sample import parse_freq from ..strategy.base import BaseStrategy from ..contrib.eva.alpha import calc_ic, calc_long_short_return @@ -315,16 +316,6 @@ class
PortAnaRecord(RecordTemp): ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"])) return ret_freq - def _cal_risk_analysis_scaler(self, freq): - _count, _freq = parse_freq(freq) - _freq_scaler = { - "minute": 240 * 250, - "day": 250, - "week": 50, - "month": 12, - } - return _count * _freq_scaler[_freq] - def generate(self, **kwargs): # custom strategy and get backtest report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config) @@ -343,12 +334,11 @@ class PortAnaRecord(RecordTemp): else: report_normal, _ = report_dict.get(self.risk_analysis_freq) analysis = dict() - risk_analysis_scaler = self._cal_risk_analysis_scaler(self.risk_analysis_freq) analysis["excess_return_without_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"], risk_analysis_scaler + report_normal["return"] - report_normal["bench"], self.risk_analysis_freq ) analysis["excess_return_with_cost"] = risk_analysis( - report_normal["return"] - report_normal["bench"] - report_normal["cost"], risk_analysis_scaler + report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq ) analysis_df = pd.concat(analysis) # type: pd.DataFrame # log metrics