diff --git a/qlib/backtest/account.py b/qlib/backtest/account.py index ce4b631ac..038bbcf60 100644 --- a/qlib/backtest/account.py +++ b/qlib/backtest/account.py @@ -26,10 +26,10 @@ rtn & earning in the Account class Account: - def __init__(self, init_cash, last_trade_date=None): - self.init_vars(init_cash, last_trade_date) + def __init__(self, init_cash, last_trade_time=None): + self.init_vars(init_cash, last_trade_time) - def init_vars(self, init_cash, last_trade_date=None): + def init_vars(self, init_cash, last_trade_time=None): # init cash self.init_cash = init_cash self.current = Position(cash=init_cash) @@ -40,7 +40,7 @@ class Account: self.val = 0 self.report = Report() self.earning = 0 - self.last_trade_date = last_trade_date + self.last_trade_time = last_trade_time def get_positions(self): return self.positions @@ -83,7 +83,7 @@ class Account: self.current.update_order(order, trade_val, cost, trade_price) self.update_state_from_order(order, trade_val, cost, trade_price) - def update_bar_end(self, start_time, end_time, trader): + def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange): """ start_time: pd.TimeStamp end_time: pd.TimeStamp @@ -103,11 +103,11 @@ class Account: profit = 0 for code in stock_list: # if suspend, no new price to be updated, profit is 0 - if trader.check_stock_suspended(code, today): + if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time): continue - today_close = trader.get_close(code, today) - profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] - self.current.update_stock_price(stock_id=code, price=today_close) + bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time) + profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"] + self.current.update_stock_price(stock_id=code, price=bar_close) self.rtn += profit # update holding day count self.current.add_count_all() @@ -117,54 
+117,54 @@ class Account: # account_value - last_account_value # for the first trade date, account_value - init_cash # self.report.is_empty() to judge is_first_trade_date - # get last_account_value, today_account_value, today_stock_value + # get last_account_value, now_account_value, now_stock_value if self.report.is_empty(): last_account_value = self.init_cash else: last_account_value = self.report.get_latest_account_value() - today_account_value = self.current.calculate_value() - today_stock_value = self.current.calculate_stock_value() - self.earning = today_account_value - last_account_value + now_account_value = self.current.calculate_value() + now_stock_value = self.current.calculate_stock_value() + self.earning = now_account_value - last_account_value # update report for today # judge whether the the trading is begin. # and don't add init account state into report, due to we don't have excess return in those days. self.report.update_report_record( - trade_date=today, - account_value=today_account_value, + trade_time=trade_start_time, + account_value=now_account_value, cash=self.current.position["cash"], return_rate=(self.earning + self.ct) / last_account_value, # here use earning to calculate return, position's view, earning consider cost, true return # in order to make same definition with original backtest in evaluate.py turnover_rate=self.to / last_account_value, cost_rate=self.ct / last_account_value, - stock_value=today_stock_value, + stock_value=now_stock_value, ) - # set today_account_value to position - self.current.position["today_account_value"] = today_account_value + # store the bar-end account value under the "today_account_value" key that Position.save_position/load_position read + self.current.position["today_account_value"] = now_account_value self.current.update_weight_all() # update positions # note use deepcopy - self.positions[today] = copy.deepcopy(self.current) + self.positions[trade_start_time] = copy.deepcopy(self.current) # finish today's updation # reset the daily variables 
self.rtn = 0 self.ct = 0 self.to = 0 - self.last_trade_date = today + self.last_trade_time = (trade_start_time, trade_end_time) def load_account(self, account_path): report = Report() position = Position() - last_trade_date = position.load_position(account_path / "position.xlsx") + last_trade_time = position.load_position(account_path / "position.xlsx") report.load_report(account_path / "report.csv") # assign values self.init_vars(position.init_cash) self.current = position self.report = report - self.last_trade_date = last_trade_date if last_trade_date else None + self.last_trade_time = last_trade_time def save_account(self, account_path): - self.current.save_position(account_path / "position.xlsx", self.last_trade_date) + self.current.save_position(account_path / "position.xlsx", self.last_trade_time) self.report.save_report(account_path / "report.csv") diff --git a/qlib/backtest/env.py b/qlib/backtest/env.py new file mode 100644 index 000000000..32ed91ef0 --- /dev/null +++ b/qlib/backtest/env.py @@ -0,0 +1,171 @@ + + +import re +import json +import copy +import pathlib +import numpy as np +import pandas as pd +from loguru import logger +from ..data import D +from ..utils import get_date_in_file_name +from ..utils import get_pre_trading_date +from ..backtest.order import Order +from ..utils import init_instance_by_config + +class BaseEnv: + """ + # Strategy framework document + + class Env(BaseEnv): + """ + + def __init__( + self, + step_bar, + trade_account, + start_time=None, + end_time=None, + track=False, + verbose=False, + **kwargs + ): + self.step_bar = step_bar + self.verbose = verbose + self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, **kwargs) + + def _reset_trade_date(self, start_time=None, end_time=None): + if start_time: + self.start_time = start_time + if end_time: + self.end_time = end_time + if not getattr(self, "start_time", None) or not getattr(self, "end_time", None): + raise ValueError("value of `start_time` or `end_time` is None") + _calendar = 
get_sample_freq_calendar(start_time=self.start_time, end_time=self.end_time, freq=self.step_bar) + self.trade_dates = np.hstack((_calendar, pd.Timestamp(self.end_time))) + self.trade_len = len(self.trade_dates) + self.trade_index = 0 + + def reset(self, start_time=None, end_time=None, **kwargs): + if start_time or end_time: + self._reset_trade_date(start_time=start_time, end_time=end_time) + self.track = kwargs.get("track", False) + self.upper_action = kwargs.get("upper_action", None) + self.trade_account = init_instance_by_config(kwargs.get("trade_account")) + return self.trade_account + + def execute(self, **kwargs): + self.trade_index = self.trade_index + 1 + return ( + self.trade_account, + { + "start_time": self.start_time, + "end_time": self.end_time, + "trade_len": self.trade_len, + # report the index of the bar that was just executed + "trade_index": self.trade_index - 1, + }, + ) + + def finished(self): + return self.trade_index >= self.trade_len - 1 + + + +class SplitEnv(BaseEnv): + def __init__( + self, + step_bar, + start_time, + end_time, + trade_account, + sub_env, + sub_strategy, + track=False, + verbose=False, + **kwargs + ): + self.sub_env = sub_env + self.sub_strategy = sub_strategy + super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose) + + def execute(self, order_list, **kwargs): + if self.finished(): + raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") + #if self.track: + # yield action + #episode_reward = 0 + trade_start_time = self.trade_dates[self.trade_index] + trade_end_time = self.trade_dates[self.trade_index + 1] + self.sub_strategy.reset(trade_order_list=order_list) + sub_account = self.sub_env.reset(trade_order_list=order_list, start_time=trade_start_time, end_time=trade_end_time) + while not self.sub_env.finished(): + sub_order_list = self.sub_strategy.generate_order(sub_account) + sub_account, sub_info = 
self.sub_env.execute(sub_order_list) + #episode_reward += sub_reward + _account, _info = super(SplitEnv, self).execute(**kwargs) + return _account, _info + + + +class SimulatorEnv(BaseEnv): + + def __init__( + self, + step_bar, + start_time, + end_time, + trade_account, + trade_exchange, + track=False, + verbose=False, + **kwargs + ): + self.trade_exchange = trade_exchange + super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose) + + def execute(self, order_list, **kwargs): + """ + Return: trade_account, info (the BaseEnv.execute info dict plus a "trade_info" list of (order, trade_val, trade_cost, trade_price)) + """ + if self.finished(): + raise StopIteration(f"this env has completed its task, please reset it if you want to call it!") + + trade_start_time = self.trade_dates[self.trade_index] + trade_end_time = self.trade_dates[self.trade_index + 1] + trade_info = [] + for order in order_list: + if self.trade_exchange.check_order(order) is True: + # execute the order + trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=self.trade_account) + trade_info.append((order, trade_val, trade_cost, trade_price)) + if self.verbose: + if order.direction == Order.SELL: # sell + print( + "[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format( + trade_start_time, + order.stock_id, + trade_price, + order.deal_amount, + trade_val, + ) + ) + else: + print( + "[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format( + trade_start_time, + order.stock_id, + trade_price, + order.deal_amount, + trade_val, + ) + ) + + else: + if self.verbose: + print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id)) + # do nothing + pass + self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange) + _account, _info = super(SimulatorEnv, self).execute(**kwargs) + return _account, {**_info, "trade_info": trade_info} \ No newline at end of file 
diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index c63651164..9945a7e8f 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -163,14 +163,15 @@ class Position: for stock_code, weight in weight_dict.items(): self.update_stock_weight(stock_code, weight) - def save_position(self, path, last_trade_date): + def save_position(self, path, last_trade_time): path = pathlib.Path(path) p = copy.deepcopy(self.position) cash = pd.Series(dtype=np.float) cash["init_cash"] = self.init_cash cash["cash"] = p["cash"] cash["today_account_value"] = p["today_account_value"] - cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None + cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time and last_trade_time[0] is not None else None + cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time and last_trade_time[1] is not None else None del p["cash"] del p["today_account_value"] positions = pd.DataFrame.from_dict(p, orient="index") @@ -201,7 +202,8 @@ class Position: init_cash = cash_record.loc["init_cash"].values[0] cash = cash_record.loc["cash"].values[0] today_account_value = cash_record.loc["today_account_value"].values[0] - last_trade_date = cash_record.loc["last_trade_date"].values[0] + last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0] + last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0] # assign values self.position = {} @@ -210,4 +212,6 @@ class Position: self.position["cash"] = cash self.position["today_account_value"] = today_account_value - return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date) + last_trade_start_time = None if pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time) + last_trade_end_time = None if pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time) + return last_trade_start_time, last_trade_end_time diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index beb9759d0..9a57156f2 100644 --- a/qlib/backtest/report.py +++ 
b/qlib/backtest/report.py @@ -21,20 +21,20 @@ class Report: self.costs = OrderedDict() # trade cost for each trade date self.values = OrderedDict() # value for each trade date self.cashes = OrderedDict() - self.latest_report_date = None # pd.TimeStamp + self.latest_report_time = None # pd.TimeStamp def is_empty(self): return len(self.accounts) == 0 def get_latest_date(self): - return self.latest_report_date + return self.latest_report_time def get_latest_account_value(self): - return self.accounts[self.latest_report_date] + return self.accounts[self.latest_report_time] def update_report_record( self, - trade_date=None, + trade_time=None, account_value=None, cash=None, return_rate=None, @@ -44,7 +44,7 @@ class Report: ): # check data if None in [ - trade_date, + trade_time, account_value, cash, return_rate, @@ -56,14 +56,14 @@ class Report: "None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]" ) # update report data - self.accounts[trade_date] = account_value - self.returns[trade_date] = return_rate - self.turnovers[trade_date] = turnover_rate - self.costs[trade_date] = cost_rate - self.values[trade_date] = stock_value - self.cashes[trade_date] = cash + self.accounts[trade_time] = account_value + self.returns[trade_time] = return_rate + self.turnovers[trade_time] = turnover_rate + self.costs[trade_time] = cost_rate + self.values[trade_time] = stock_value + self.cashes[trade_time] = cash # update latest_report_date - self.latest_report_date = trade_date + self.latest_report_time = trade_time # finish daily report update def generate_report_dataframe(self): @@ -74,7 +74,7 @@ class Report: report["cost"] = pd.Series(self.costs) report["value"] = pd.Series(self.values) report["cash"] = pd.Series(self.cashes) - report.index.name = "date" + report.index.name = "trade_time" return report def save_report(self, path): @@ -94,13 +94,13 @@ class Report: index = r.index self.init_vars() - for date in index: + for trade_time in index: 
self.update_report_record( - trade_date=date, - account_value=r.loc[date]["account"], - cash=r.loc[date]["cash"], - return_rate=r.loc[date]["return"], - turnover_rate=r.loc[date]["turnover"], - cost_rate=r.loc[date]["cost"], - stock_value=r.loc[date]["value"], + trade_time=trade_time, + account_value=r.loc[trade_time]["account"], + cash=r.loc[trade_time]["cash"], + return_rate=r.loc[trade_time]["return"], + turnover_rate=r.loc[trade_time]["turnover"], + cost_rate=r.loc[trade_time]["cost"], + stock_value=r.loc[trade_time]["value"], ) diff --git a/qlib/contrib/backtest_new/backtest.py b/qlib/contrib/backtest_new/backtest.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index 692959f21..03b9d88c0 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -32,7 +32,7 @@ class BaseStrategy: if not self.start_time or not self.end_time: raise ValueError("value of `start_time` or `end_time` is None") _calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar) - self.trade_dates = np.hstack(pd.Timestamp(self.start_time), _calendar, self.end_time) + self.trade_dates = np.hstack((_calendar, pd.Timestamp(self.end_time))) self.trade_len = len(self.trade_dates) self.trade_index = 0 diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index 1c0ef87a4..2cd2f5d13 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -866,14 +866,15 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): """ freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam - + if not len(calendar_raw): + return calendar_raw if freq_sam.endswith(("minute", "min")): def cal_next_sam_minute(x, sam_minutes): hour = x.hour minute = x.minute - if 9 <= hour <= 11: + if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30): minute_index = (hour - 9)*60 + minute - 30 - 
elif 13 <= hour <= 15: + elif 13 <= hour < 15: minute_index = (hour - 13)*60 + minute + 120 else: raise ValueError("calendar hour must be in [9, 11] or [13, 15]") @@ -894,6 +895,8 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam): if raw_minutes > sam_minutes: raise ValueError("raw freq must be higher than sample freq") _calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw))) + if calendar_raw[0] > _calendar_minute[0]: + _calendar_minute[0] = calendar_raw[0] return _calendar_minute else: _calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw))) @@ -944,4 +947,5 @@ def sample_feature(feature, instruments=None, start_time=None, end_time=None, fi if method: return getattr(feature.groupby(level="instrument"), method)(**method_kwargs) else: - return feature \ No newline at end of file + return feature +