mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
update report & account
This commit is contained in:
@@ -26,10 +26,10 @@ rtn & earning in the Account
|
||||
|
||||
|
||||
class Account:
|
||||
def __init__(self, init_cash, last_trade_date=None):
|
||||
self.init_vars(init_cash, last_trade_date)
|
||||
def __init__(self, init_cash, last_trade_time=None):
|
||||
self.init_vars(init_cash, last_trade_time)
|
||||
|
||||
def init_vars(self, init_cash, last_trade_date=None):
|
||||
def init_vars(self, init_cash, last_trade_time=None):
|
||||
# init cash
|
||||
self.init_cash = init_cash
|
||||
self.current = Position(cash=init_cash)
|
||||
@@ -40,7 +40,7 @@ class Account:
|
||||
self.val = 0
|
||||
self.report = Report()
|
||||
self.earning = 0
|
||||
self.last_trade_date = last_trade_date
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
@@ -83,7 +83,7 @@ class Account:
|
||||
self.current.update_order(order, trade_val, cost, trade_price)
|
||||
self.update_state_from_order(order, trade_val, cost, trade_price)
|
||||
|
||||
def update_bar_end(self, start_time, end_time, trader):
|
||||
def update_bar_end(self, trade_start_time, trade_end_time, trade_exchange):
|
||||
"""
|
||||
start_time: pd.TimeStamp
|
||||
end_time: pd.TimeStamp
|
||||
@@ -103,11 +103,11 @@ class Account:
|
||||
profit = 0
|
||||
for code in stock_list:
|
||||
# if suspend, no new price to be updated, profit is 0
|
||||
if trader.check_stock_suspended(code, today):
|
||||
if trade_exchange.check_stock_suspended(code, trade_start_time, trade_end_time):
|
||||
continue
|
||||
today_close = trader.get_close(code, today)
|
||||
profit += (today_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=today_close)
|
||||
bar_close = trade_exchange.get_close(code, trade_start_time, trade_end_time)
|
||||
profit += (bar_close - self.current.position[code]["price"]) * self.current.position[code]["amount"]
|
||||
self.current.update_stock_price(stock_id=code, price=bar_close)
|
||||
self.rtn += profit
|
||||
# update holding day count
|
||||
self.current.add_count_all()
|
||||
@@ -117,54 +117,55 @@ class Account:
|
||||
# account_value - last_account_value
|
||||
# for the first trade date, account_value - init_cash
|
||||
# self.report.is_empty() to judge is_first_trade_date
|
||||
# get last_account_value, today_account_value, today_stock_value
|
||||
# get last_account_value, now_account_value, now_stock_value
|
||||
if self.report.is_empty():
|
||||
last_account_value = self.init_cash
|
||||
else:
|
||||
last_account_value = self.report.get_latest_account_value()
|
||||
today_account_value = self.current.calculate_value()
|
||||
today_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = today_account_value - last_account_value
|
||||
now_account_value = self.current.calculate_value()
|
||||
now_stock_value = self.current.calculate_stock_value()
|
||||
self.earning = now_account_value - last_account_value
|
||||
# update report for today
|
||||
# judge whether the the trading is begin.
|
||||
# and don't add init account state into report, due to we don't have excess return in those days.
|
||||
self.report.update_report_record(
|
||||
trade_date=today,
|
||||
account_value=today_account_value,
|
||||
trade_start_time=trade_start_time,
|
||||
trade_end_time=trade_end_time,
|
||||
account_value=now_account_value,
|
||||
cash=self.current.position["cash"],
|
||||
return_rate=(self.earning + self.ct) / last_account_value,
|
||||
# here use earning to calculate return, position's view, earning consider cost, true return
|
||||
# in order to make same definition with original backtest in evaluate.py
|
||||
turnover_rate=self.to / last_account_value,
|
||||
cost_rate=self.ct / last_account_value,
|
||||
stock_value=today_stock_value,
|
||||
stock_value=now_stock_value,
|
||||
)
|
||||
# set today_account_value to position
|
||||
self.current.position["today_account_value"] = today_account_value
|
||||
# set now_account_value to position
|
||||
self.current.position["now_account_value"] = now_account_value
|
||||
self.current.update_weight_all()
|
||||
# update positions
|
||||
# note use deepcopy
|
||||
self.positions[today] = copy.deepcopy(self.current)
|
||||
self.positions[trade_start_time] = copy.deepcopy(self.current)
|
||||
|
||||
# finish today's updation
|
||||
# reset the daily variables
|
||||
self.rtn = 0
|
||||
self.ct = 0
|
||||
self.to = 0
|
||||
self.last_trade_date = today
|
||||
self.last_trade_time = (trade_start_time, trade_end_time)
|
||||
|
||||
def load_account(self, account_path):
|
||||
report = Report()
|
||||
position = Position()
|
||||
last_trade_date = position.load_position(account_path / "position.xlsx")
|
||||
last_trade_time = position.load_position(account_path / "position.xlsx")
|
||||
report.load_report(account_path / "report.csv")
|
||||
|
||||
# assign values
|
||||
self.init_vars(position.init_cash)
|
||||
self.current = position
|
||||
self.report = report
|
||||
self.last_trade_date = last_trade_date if last_trade_date else None
|
||||
self.last_trade_time = last_trade_time
|
||||
|
||||
def save_account(self, account_path):
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_date)
|
||||
self.current.save_position(account_path / "position.xlsx", self.last_trade_time)
|
||||
self.report.save_report(account_path / "report.csv")
|
||||
|
||||
169
qlib/backtest/env.py
Normal file
169
qlib/backtest/env.py
Normal file
@@ -0,0 +1,169 @@
|
||||
|
||||
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
from loguru import Logger
|
||||
from ...data import D
|
||||
from ...utils import get_date_in_file_name
|
||||
from ...utils import get_pre_trading_date
|
||||
from ..backtest.order import Order
|
||||
from ..utils import init_instance_by_config
|
||||
|
||||
class BaseEnv:
|
||||
"""
|
||||
# Strategy framework document
|
||||
|
||||
class Env(BaseEnv):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
trade_account,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
track=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, **kwargs)
|
||||
|
||||
def _reset_trade_date(self, start_time=None, end_time=None):
|
||||
if start_time:
|
||||
self.start_time = start_time
|
||||
if end_time:
|
||||
self.end_time = end_time
|
||||
if not self.start_time or not self.end_time:
|
||||
raise ValueError("value of `start_time` or `end_time` is None")
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_dates)
|
||||
self.trade_index = 0
|
||||
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
if start_time or end_time:
|
||||
self._reset_trade_date(start_time=start_time, end_time=end_time)
|
||||
self.track = kwargs.get("track", False)
|
||||
self.upper_action = kwargs.get("upper_action", None)
|
||||
self.trade_account = init_instance_by_config(kwargs.get("trade_account"))
|
||||
return self.trade_account
|
||||
|
||||
def execute(self, **kwargs):
|
||||
self.trade_index = self.trade_index + 1
|
||||
return
|
||||
(
|
||||
self.trade_account,
|
||||
{
|
||||
"start_time": self.start_time,
|
||||
"end_time": self.end_time,
|
||||
"trade_len": self.trade_len,
|
||||
"trade_index": self.trade_index - 1,
|
||||
}
|
||||
)
|
||||
|
||||
def finished(self):
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
|
||||
|
||||
|
||||
class SplitEnv(BaseEnv):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
trade_account,
|
||||
sub_env,
|
||||
sub_strategy,
|
||||
track=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.sub_env = sub_env
|
||||
self.sub_strategy = sub_strategy
|
||||
super(SplitEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track)
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
#if self.track:
|
||||
# yield action
|
||||
#episode_reward = 0
|
||||
trade_start_time = self.trade_dates[self.trade_index]
|
||||
trade_end_time = self.trade_dates[self.trade_index + 1]
|
||||
self.sub_strategy.reset(trade_order_list=order_list)
|
||||
sub_account = self.sub_env.reset(trade_order_list=order_list, start_time=self.trade_dates[self.trade_index - 1], end_time=self.trade_dates[self.trade_index])
|
||||
while not self.sub_env.finished():
|
||||
sub_order_list = self.sub_strategy.generate_order(sub_account)
|
||||
sub_account, sub_info = self.sub_env.execute(sub_order_list)
|
||||
#episode_reward += sub_reward
|
||||
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
|
||||
return _account, _info
|
||||
|
||||
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time,
|
||||
end_time,
|
||||
trade_account,
|
||||
trade_exchange,
|
||||
track=False,
|
||||
verbose=False,
|
||||
**kwargs
|
||||
):
|
||||
self.trade_exchange = trade_exchange
|
||||
super(SimulatorEnv, self).__init__(step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, track=track, verbose=verbose)
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
"""
|
||||
Return: obs, done, info
|
||||
"""
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
|
||||
trade_start_time = self.trade_dates[self.trade_index]
|
||||
trade_end_time = self.trade_dates[self.trade_index + 1]
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=self.trade_account)
|
||||
trade_info.append((order, trade_val, trade_cost, trade_price))
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.deal_amount,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.deal_amount,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
if self.verbose:
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_start_time, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
self.trade_account.update_bar_end(trade_start_time=trade_start_time, trade_end_time=trade_end_time, trade_exchange=self.trade_exchange)
|
||||
_account, _info = super(SimulatorEnv, self).execute(**kwargs)
|
||||
return _account, {**_info, "trade_info", trade_info}
|
||||
@@ -163,14 +163,15 @@ class Position:
|
||||
for stock_code, weight in weight_dict.items():
|
||||
self.update_stock_weight(stock_code, weight)
|
||||
|
||||
def save_position(self, path, last_trade_date):
|
||||
def save_position(self, path, last_trade_time):
|
||||
path = pathlib.Path(path)
|
||||
p = copy.deepcopy(self.position)
|
||||
cash = pd.Series(dtype=np.float)
|
||||
cash["init_cash"] = self.init_cash
|
||||
cash["cash"] = p["cash"]
|
||||
cash["today_account_value"] = p["today_account_value"]
|
||||
cash["last_trade_date"] = str(last_trade_date.date()) if last_trade_date else None
|
||||
cash["last_trade_start_time"] = str(last_trade_time[0]) if last_trade_time else None
|
||||
cash["last_trade_end_time"] = str(last_trade_time[1]) if last_trade_time else None
|
||||
del p["cash"]
|
||||
del p["today_account_value"]
|
||||
positions = pd.DataFrame.from_dict(p, orient="index")
|
||||
@@ -201,7 +202,8 @@ class Position:
|
||||
init_cash = cash_record.loc["init_cash"].values[0]
|
||||
cash = cash_record.loc["cash"].values[0]
|
||||
today_account_value = cash_record.loc["today_account_value"].values[0]
|
||||
last_trade_date = cash_record.loc["last_trade_date"].values[0]
|
||||
last_trade_start_time = cash_record.loc["last_trade_start_time"].values[0]
|
||||
last_trade_end_time = cash_record.loc["last_trade_end_time"].values[0]
|
||||
|
||||
# assign values
|
||||
self.position = {}
|
||||
@@ -210,4 +212,6 @@ class Position:
|
||||
self.position["cash"] = cash
|
||||
self.position["today_account_value"] = today_account_value
|
||||
|
||||
return None if pd.isna(last_trade_date) else pd.Timestamp(last_trade_date)
|
||||
last_trade_start_time = None is pd.isna(last_trade_start_time) else pd.Timestamp(last_trade_start_time)
|
||||
last_trade_end_time = None is pd.isna(last_trade_end_time) else pd.Timestamp(last_trade_end_time)
|
||||
return last_trade_start_time, last_trade_end_time
|
||||
|
||||
@@ -21,20 +21,20 @@ class Report:
|
||||
self.costs = OrderedDict() # trade cost for each trade date
|
||||
self.values = OrderedDict() # value for each trade date
|
||||
self.cashes = OrderedDict()
|
||||
self.latest_report_date = None # pd.TimeStamp
|
||||
self.latest_report_time = None # pd.TimeStamp
|
||||
|
||||
def is_empty(self):
|
||||
return len(self.accounts) == 0
|
||||
|
||||
def get_latest_date(self):
|
||||
return self.latest_report_date
|
||||
return self.latest_report_time
|
||||
|
||||
def get_latest_account_value(self):
|
||||
return self.accounts[self.latest_report_date]
|
||||
return self.accounts[self.latest_report_time]
|
||||
|
||||
def update_report_record(
|
||||
self,
|
||||
trade_date=None,
|
||||
trade_time=None,
|
||||
account_value=None,
|
||||
cash=None,
|
||||
return_rate=None,
|
||||
@@ -44,7 +44,7 @@ class Report:
|
||||
):
|
||||
# check data
|
||||
if None in [
|
||||
trade_date,
|
||||
trade_time,
|
||||
account_value,
|
||||
cash,
|
||||
return_rate,
|
||||
@@ -56,14 +56,14 @@ class Report:
|
||||
"None in [trade_date, account_value, cash, return_rate, turnover_rate, cost_rate, stock_value]"
|
||||
)
|
||||
# update report data
|
||||
self.accounts[trade_date] = account_value
|
||||
self.returns[trade_date] = return_rate
|
||||
self.turnovers[trade_date] = turnover_rate
|
||||
self.costs[trade_date] = cost_rate
|
||||
self.values[trade_date] = stock_value
|
||||
self.cashes[trade_date] = cash
|
||||
self.accounts[trade_time] = account_value
|
||||
self.returns[trade_time] = return_rate
|
||||
self.turnovers[trade_time] = turnover_rate
|
||||
self.costs[trade_time] = cost_rate
|
||||
self.values[trade_time] = stock_value
|
||||
self.cashes[trade_time] = cash
|
||||
# update latest_report_date
|
||||
self.latest_report_date = trade_date
|
||||
self.latest_report_time = trade_time
|
||||
# finish daily report update
|
||||
|
||||
def generate_report_dataframe(self):
|
||||
@@ -74,7 +74,7 @@ class Report:
|
||||
report["cost"] = pd.Series(self.costs)
|
||||
report["value"] = pd.Series(self.values)
|
||||
report["cash"] = pd.Series(self.cashes)
|
||||
report.index.name = "date"
|
||||
report.index.name = "trade_time"
|
||||
return report
|
||||
|
||||
def save_report(self, path):
|
||||
@@ -94,13 +94,13 @@ class Report:
|
||||
|
||||
index = r.index
|
||||
self.init_vars()
|
||||
for date in index:
|
||||
for trade_time in index:
|
||||
self.update_report_record(
|
||||
trade_date=date,
|
||||
account_value=r.loc[date]["account"],
|
||||
cash=r.loc[date]["cash"],
|
||||
return_rate=r.loc[date]["return"],
|
||||
turnover_rate=r.loc[date]["turnover"],
|
||||
cost_rate=r.loc[date]["cost"],
|
||||
stock_value=r.loc[date]["value"],
|
||||
trade_time=trade_time,
|
||||
account_value=r.loc[trade_time]["account"],
|
||||
cash=r.loc[trade_time]["cash"],
|
||||
return_rate=r.loc[trade_time]["return"],
|
||||
turnover_rate=r.loc[trade_time]["turnover"],
|
||||
cost_rate=r.loc[trade_time]["cost"],
|
||||
stock_value=r.loc[trade_time]["value"],
|
||||
)
|
||||
|
||||
0
qlib/contrib/backtest_new/backtest.py
Normal file
0
qlib/contrib/backtest_new/backtest.py
Normal file
@@ -32,7 +32,7 @@ class BaseStrategy:
|
||||
if not self.start_time or not self.end_time:
|
||||
raise ValueError("value of `start_time` or `end_time` is None")
|
||||
_calendar = get_sample_freq_calendar(start_time=start_time, end_time=end_time, freq=step_bar)
|
||||
self.trade_dates = np.hstack(pd.Timestamp(self.start_time), _calendar, self.end_time)
|
||||
self.trade_dates = np.hstack(_calendar, pd.Timestamp(self.end_time))
|
||||
self.trade_len = len(self.trade_dates)
|
||||
self.trade_index = 0
|
||||
|
||||
|
||||
@@ -866,14 +866,15 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
|
||||
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
|
||||
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
if freq_sam.endswith(("minute", "min")):
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if 9 <= hour <= 11:
|
||||
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
|
||||
minute_index = (hour - 9)*60 + minute - 30
|
||||
elif 13 <= hour <= 15:
|
||||
elif 13 <= hour < 15:
|
||||
minute_index = (hour - 13)*60 + minute + 120
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
@@ -894,6 +895,8 @@ def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
if raw_minutes > sam_minutes:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
_calendar_minute = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 0), calendar_raw)))
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
return _calendar_minute
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
@@ -944,4 +947,5 @@ def sample_feature(feature, instruments=None, start_time=None, end_time=None, fi
|
||||
if method:
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
return feature
|
||||
return feature
|
||||
|
||||
|
||||
Reference in New Issue
Block a user