1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-03 02:50:58 +08:00
Files
qlib/qlib/contrib/evaluate.py
2021-03-08 19:43:03 +08:00

270 lines
8.8 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import division
from __future__ import print_function
import numpy as np
import pandas as pd
import warnings
from ..log import get_module_logger
from .backtest import get_exchange, backtest as backtest_func
from .backtest.backtest import get_date_range
from ..data import D
from ..config import C
from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")
def risk_analysis(r, N=252):
"""Risk Analysis
Parameters
----------
r : pandas.Series
daily return series.
N: int
scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
"""
mean = r.mean()
std = r.std(ddof=1)
annualized_return = mean * N
information_ratio = mean / std * np.sqrt(N)
max_drawdown = (r.cumsum() - r.cumsum().cummax()).min()
data = {
"mean": mean,
"std": std,
"annualized_return": annualized_return,
"information_ratio": information_ratio,
"max_drawdown": max_drawdown,
}
res = pd.Series(data, index=data.keys()).to_frame("risk")
return res
# This is the API for compatibility for legacy code
def backtest(pred, account=1e9, shift=1, benchmark="SH000905", verbose=True, **kwargs):
"""This function will help you set a reasonable Exchange and provide default value for strategy
Parameters
----------
- **backtest workflow related or commmon arguments**
pred : pandas.DataFrame
predict should has <datetime, instrument> index and one `score` column.
account : float
init account value.
shift : int
whether to shift prediction by one day.
benchmark : str
benchmark code, default is SH000905 CSI 500.
verbose : bool
whether to print log.
- **strategy related arguments**
strategy : Strategy()
strategy used in backtest.
topk : int (Default value: 50)
top-N stocks to buy.
margin : int or float(Default value: 0.5)
- if isinstance(margin, int):
sell_limit = margin
- else:
sell_limit = pred_in_a_day.count() * margin
buffer margin, in single score_mode, continue holding stock if it is in nlargest(sell_limit).
sell_limit should be no less than topk.
n_drop : int
number of stocks to be replaced in each trading date.
risk_degree: float
0-1, 0.95 for example, use 95% money to trade.
str_type: 'amount', 'weight' or 'dropout'
strategy type: TopkAmountStrategy ,TopkWeightStrategy or TopkDropoutStrategy.
- **exchange related arguments**
exchange: Exchange()
pass the exchange for speeding up.
subscribe_fields: list
subscribe fields.
open_cost : float
open transaction cost. The default value is 0.002(0.2%).
close_cost : float
close transaction cost. The default value is 0.002(0.2%).
min_cost : float
min transaction cost.
trade_unit : int
100 for China A.
deal_price: str
dealing price type: 'close', 'open', 'vwap'.
limit_threshold : float
limit move 0.1 (10%) for example, long and short with same limit.
extract_codes: bool
will we pass the codes extracted from the pred to the exchange.
.. note:: This will be faster with offline qlib.
- **executor related arguments**
executor : BaseExecutor()
executor used in backtest.
verbose : bool
whether to print log.
"""
warnings.warn(
"this function is deprecated, please use backtest function in qlib.contrib.backtest", DeprecationWarning
)
report_dict = backtest_func(
pred=pred, account=account, shift=shift, benchmark=benchmark, verbose=verbose, return_order=False, **kwargs
)
return report_dict.get("report_df"), report_dict.get("positions")
def long_short_backtest(
pred,
topk=50,
deal_price=None,
shift=1,
open_cost=0,
close_cost=0,
trade_unit=None,
limit_threshold=None,
min_cost=5,
subscribe_fields=[],
extract_codes=False,
):
"""
A backtest for long-short strategy
:param pred: The trading signal produced on day `T`.
:param topk: The short topk securities and long topk securities.
:param deal_price: The price to deal the trading.
:param shift: Whether to shift prediction by one day. The trading day will be T+1 if shift==1.
:param open_cost: open transaction cost.
:param close_cost: close transaction cost.
:param trade_unit: 100 for China A.
:param limit_threshold: limit move 0.1 (10%) for example, long and short with same limit.
:param min_cost: min transaction cost.
:param subscribe_fields: subscribe fields.
:param extract_codes: bool.
will we pass the codes extracted from the pred to the exchange.
NOTE: This will be faster with offline qlib.
:return: The result of backtest, it is represented by a dict.
{ "long": long_returns(excess),
"short": short_returns(excess),
"long_short": long_short_returns}
"""
if get_level_index(pred, level="datetime") == 1:
pred = pred.swaplevel().sort_index()
if trade_unit is None:
trade_unit = C.trade_unit
if limit_threshold is None:
limit_threshold = C.limit_threshold
if deal_price is None:
deal_price = C.deal_price
if deal_price[0] != "$":
deal_price = "$" + deal_price
subscribe_fields = subscribe_fields.copy()
profit_str = f"Ref({deal_price}, -1)/{deal_price} - 1"
subscribe_fields.append(profit_str)
trade_exchange = get_exchange(
pred=pred,
deal_price=deal_price,
subscribe_fields=subscribe_fields,
limit_threshold=limit_threshold,
open_cost=open_cost,
close_cost=close_cost,
min_cost=min_cost,
trade_unit=trade_unit,
extract_codes=extract_codes,
shift=shift,
)
_pred_dates = pred.index.get_level_values(level="datetime")
predict_dates = D.calendar(start_time=_pred_dates.min(), end_time=_pred_dates.max())
trade_dates = np.append(predict_dates[shift:], get_date_range(predict_dates[-1], left_shift=1, right_shift=shift))
long_returns = {}
short_returns = {}
ls_returns = {}
for pdate, date in zip(predict_dates, trade_dates):
score = pred.loc(axis=0)[pdate, :]
score = score.reset_index().sort_values(by="score", ascending=False)
long_stocks = list(score.iloc[:topk]["instrument"])
short_stocks = list(score.iloc[-topk:]["instrument"])
score = score.set_index(["datetime", "instrument"]).sort_index()
long_profit = []
short_profit = []
all_profit = []
for stock in long_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
long_profit.append(0)
else:
long_profit.append(profit)
for stock in short_stocks:
if not trade_exchange.is_stock_tradable(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
short_profit.append(0)
else:
short_profit.append(-profit)
for stock in list(score.loc(axis=0)[pdate, :].index.get_level_values(level=0)):
# exclude the suspend stock
if trade_exchange.check_stock_suspended(stock_id=stock, trade_date=date):
continue
profit = trade_exchange.get_quote_info(stock_id=stock, trade_date=date)[profit_str]
if np.isnan(profit):
all_profit.append(0)
else:
all_profit.append(profit)
long_returns[date] = np.mean(long_profit) - np.mean(all_profit)
short_returns[date] = np.mean(short_profit) + np.mean(all_profit)
ls_returns[date] = np.mean(short_profit) + np.mean(long_profit)
return dict(
zip(
["long", "short", "long_short"],
map(pd.Series, [long_returns, short_returns, ls_returns]),
)
)
def t_run():
pred_FN = "./check_pred.csv"
pred = pd.read_csv(pred_FN)
pred["datetime"] = pd.to_datetime(pred["datetime"])
pred = pred.set_index([pred.columns[0], pred.columns[1]])
pred = pred.iloc[:9000]
report_df, positions = backtest(pred=pred)
print(report_df.head())
print(positions.keys())
print(positions[list(positions.keys())[0]])
return 0
if __name__ == "__main__":
t_run()