1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

fix some comments and add docstring

This commit is contained in:
bxdd
2021-05-12 02:17:39 +08:00
parent f7d30960c1
commit 621cb243c2
25 changed files with 795 additions and 712 deletions

View File

@@ -91,13 +91,13 @@ if __name__ == "__main__":
},
},
"env": {
"class": "SplitEnv",
"module_path": "qlib.contrib.backtest.env",
"class": "SplitExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
"step_bar": "week",
"sub_env": {
"class": "SimulatorEnv",
"module_path": "qlib.contrib.backtest.env",
"class": "SimulatorExecutor",
"module_path": "qlib.contrib.backtest.executor",
"kwargs": {
"step_bar": "day",
"verbose": True,
@@ -118,14 +118,17 @@ if __name__ == "__main__":
"backtest": {
"start_time": trade_start_time,
"end_time": trade_end_time,
"verbose": False,
"limit_threshold": 0.095,
"account": 100000000,
"benchmark": benchmark,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
"exchange_kwargs": {
"freq": "day",
"verbose": False,
"limit_threshold": 0.095,
"deal_price": "close",
"open_cost": 0.0005,
"close_cost": 0.0015,
"min_cost": 5,
},
},
}

View File

@@ -1,15 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .order import Order
from .position import Position
from .exchange import Exchange
from .report import Report
from .executor import BaseExecutor
from .backtest import backtest as backtest_func
import copy
import numpy as np
import inspect
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...log import get_module_logger
from ...config import C
@@ -90,21 +88,6 @@ def get_exchange(
return init_instance_by_config(exchange, accept_types=Exchange)
def init_env_instance_by_config(env):
if isinstance(env, dict):
env_config = copy.copy(env)
if "kwargs" in env_config:
env_kwargs = copy.copy(env_config["kwargs"])
if "sub_env" in env_kwargs:
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
if "sub_strategy" in env_kwargs:
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
env_config["kwargs"] = env_kwargs
return init_instance_by_config(env_config)
else:
return env
def setup_exchange(root_instance, trade_exchange=None, force=False):
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
if force:
@@ -118,13 +101,11 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
setup_exchange(root_instance.sub_strategy, trade_exchange)
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, **kwargs):
trade_strategy = init_instance_by_config(strategy)
trade_env = init_env_instance_by_config(env)
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}):
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
spec = inspect.getfullargspec(get_exchange)
exchange_args = {k: v for k, v in kwargs.items() if k in spec.args}
trade_exchange = get_exchange(**exchange_args)
trade_exchange = get_exchange(**exchange_kwargs)
setup_exchange(trade_env, trade_exchange)
setup_exchange(trade_strategy, trade_exchange)

View File

@@ -3,13 +3,14 @@
import copy
import warnings
import pandas as pd
from .position import Position
from .report import Report
from .order import Order
from ...data import D
from ...utils import parse_freq, sample_feature
from ...utils.sample import parse_freq, sample_feature
"""
@@ -110,6 +111,8 @@ class Account:
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
else:
warnings.warn(f"reser error, attribute {k} is not found!")
def get_positions(self):
return self.positions

View File

@@ -1,10 +1,6 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import pandas as pd
from .account import Account
@@ -14,9 +10,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
trade_strategy.reset(start_time=start_time, end_time=end_time)
trade_state = trade_env.get_init_state()
_execute_state = trade_env.get_init_state()
while not trade_env.finished():
_order_list = trade_strategy.generate_order_list(**trade_state)
trade_state, trade_info = trade_env.execute(_order_list)
_order_list = trade_strategy.generate_order_list(_execute_state)
_execute_state = trade_env.execute(_order_list)
return trade_env.get_report()

View File

@@ -11,7 +11,7 @@ import pandas as pd
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...config import C, REG_CN
from ...utils import sample_feature
from ...utils.sample import sample_feature
from ...log import get_module_logger
from .order import Order

View File

@@ -1,19 +1,34 @@
import re
import json
import copy
import warnings
import pathlib
import numpy as np
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
from ...data.data import Cal
from ...utils import get_sample_freq_calendar, parse_freq
from .position import Position
from ...strategy.base import BaseStrategy
from ...utils import init_instance_by_config
from ...utils.sample import get_sample_freq_calendar, parse_freq
from .report import Report
from .order import Order
from .account import Account
from .exchange import Exchange
class BaseTradeCalendar:
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
def __init__(
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
):
"""
Parameters
----------
step_bar : str
frequency of each trading step bar
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
If `start_time` is None, it must be reset before trading.
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
If `end_time` is None, it must be reset before trading.
"""
self.step_bar = step_bar
self.reset(start_time=start_time, end_time=end_time)
@@ -27,10 +42,9 @@ class BaseTradeCalendar:
if self.start_time and self.end_time:
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
self.calendar = _calendar
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(
_, _, _start_index, _end_index = Cal.locate_index(
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
)
_trade_calendar = self.calendar[_start_index : _end_index + 1]
self.start_index = _start_index
self.end_index = _end_index
self.trade_len = _end_index - _start_index + 1
@@ -45,6 +59,8 @@ class BaseTradeCalendar:
for k, v in kwargs.items():
if hasattr(self, k):
setattr(self, k, v)
else:
warnings.warn(f"reser error, attribute {k} is not found!")
def _get_calendar_time(self, trade_index=1, shift=0):
trade_index = trade_index - shift
@@ -55,34 +71,43 @@ class BaseTradeCalendar:
return self.trade_index >= self.trade_len - 1
def step(self):
if self.finished():
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
self.trade_index = self.trade_index + 1
class BaseEnv(BaseTradeCalendar):
"""
# Strategy framework document
class Env(BaseEnv):
"""
class BaseExecutor(BaseTradeCalendar):
"""Base executor for trading"""
def __init__(
self,
step_bar,
start_time=None,
end_time=None,
trade_account=None,
generate_report=False,
verbose=False,
step_bar: str,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
generate_report: bool = False,
verbose: bool = False,
**kwargs,
):
self.generate_report = generate_report
self.verbose = verbose
super(BaseEnv, self).__init__(
"""
Parameters
----------
trade_account : Account, optional
trade account for trading, by default None
If `trade_account` is None, it must be reset before trading
generate_report : bool, optional
whether to generate report, by default False
verbose : bool, optional
whether to print log, by default False
"""
super(BaseExecutor, self).__init__(
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
)
self.generate_report = generate_report
self.verbose = verbose
def reset(self, trade_account=None, **kwargs):
super(BaseEnv, self).reset(**kwargs)
super(BaseExecutor, self).reset(**kwargs)
if trade_account:
self.trade_account = trade_account
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
@@ -101,23 +126,31 @@ class BaseEnv(BaseTradeCalendar):
raise NotImplementedError("get_report is not implemented!")
class SplitEnv(BaseEnv):
class SplitExecutor(BaseExecutor):
def __init__(
self,
step_bar,
sub_env,
sub_strategy,
start_time=None,
end_time=None,
trade_account=None,
trade_exchange=None,
generate_report=False,
verbose=False,
step_bar: str,
sub_env: Union[BaseExecutor, dict],
sub_strategy: Union[BaseStrategy, dict],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
trade_account: Account = None,
trade_exchange: Exchange = None,
generate_report: bool = False,
verbose: bool = False,
**kwargs,
):
self.sub_env = sub_env
self.sub_strategy = sub_strategy
super(SplitEnv, self).__init__(
"""
Parameters
----------
sub_env : BaseExecutor
trading env in each trading bar.
sub_strategy : BaseStrategy
trading strategy in each trading bar
trade_exchange : Exchange
exchange that provides market info
"""
super(SplitExecutor, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
@@ -127,28 +160,26 @@ class SplitEnv(BaseEnv):
verbose=verbose,
**kwargs,
)
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy)
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs)
if trade_account:
self.sub_env.reset(trade_account=copy.copy(trade_account))
if trade_exchange:
self.trade_exchange = trade_exchange
def execute(self, order_list, **kwargs):
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
# if self.track:
# yield action
# episode_reward = 0
super(SplitEnv, self).step()
def execute(self, order_list):
super(SplitExecutor, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
trade_state = self.sub_env.get_init_state()
_execute_state = self.sub_env.get_init_state()
while not self.sub_env.finished():
_order_list = self.sub_strategy.generate_order_list(**trade_state)
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
_order_list = self.sub_strategy.generate_order_list(_execute_state)
_execute_state = self.sub_env.execute(order_list=_order_list)
self.trade_account.update_bar_end(
trade_start_time=trade_start_time,
@@ -156,9 +187,8 @@ class SplitEnv(BaseEnv):
trade_exchange=self.trade_exchange,
update_report=self.generate_report,
)
_obs = {"current": self.trade_account.current}
_info = {}
return _obs, _info
_execute_state = {"current": self.trade_account.current}
return _execute_state
def get_report(self):
sub_env_report_dict = self.sub_env.get_report()
@@ -167,12 +197,10 @@ class SplitEnv(BaseEnv):
_positions = self.trade_account.get_positions()
_count, _freq = parse_freq(self.step_bar)
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
return sub_env_report_dict
else:
return sub_env_report_dict
return sub_env_report_dict
class SimulatorEnv(BaseEnv):
class SimulatorExecutor(BaseExecutor):
def __init__(
self,
step_bar,
@@ -184,7 +212,13 @@ class SimulatorEnv(BaseEnv):
verbose=False,
**kwargs,
):
super(SimulatorEnv, self).__init__(
"""
Parameters
----------
trade_exchange : Exchange
exchange that provides market info
"""
super(SimulatorExecutor, self).__init__(
step_bar=step_bar,
start_time=start_time,
end_time=end_time,
@@ -196,17 +230,12 @@ class SimulatorEnv(BaseEnv):
)
def reset(self, trade_exchange=None, **kwargs):
super(SimulatorEnv, self).reset(**kwargs)
super(SimulatorExecutor, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange = trade_exchange
def execute(self, order_list, **kwargs):
"""
Return: obs, done, info
"""
if self.finished():
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
super(SimulatorEnv, self).step()
def execute(self, order_list):
super(SimulatorExecutor, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
trade_info = []
for order in order_list:
@@ -219,21 +248,25 @@ class SimulatorEnv(BaseEnv):
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
order.amount,
order.deal_amount,
order.factor,
trade_val,
)
)
else:
print(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
trade_start_time,
order.stock_id,
trade_price,
order.amount,
order.deal_amount,
order.factor,
trade_val,
)
)
@@ -249,9 +282,8 @@ class SimulatorEnv(BaseEnv):
trade_exchange=self.trade_exchange,
update_report=self.generate_report,
)
_obs = {"current": self.trade_account.current}
_info = {"trade_info": trade_info}
return _obs, _info
_execute_state = {"current": self.trade_account.current, "trade_info": trade_info}
return _execute_state
def get_report(self):
if self.generate_report:

View File

@@ -1,16 +0,0 @@
class BaseInterpreter:
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter:
@staticmethod
def interpret(action, **kwargs):
return action
class StateInterpreter:
@staticmethod
def interpret(state, **kwargs):
return state

View File

@@ -10,6 +10,7 @@ import warnings
from ..log import get_module_logger
from .backtest import get_exchange, backtest as backtest_func
from ..utils import get_date_range
from ..utils.sample import parse_freq
from ..data import D
from ..config import C
@@ -19,7 +20,7 @@ from ..data.dataset.utils import get_level_index
logger = get_module_logger("Evaluate")
def risk_analysis(r, N=252):
def risk_analysis(r, N: int = None, freq: str = None):
"""Risk Analysis
Parameters
@@ -27,8 +28,26 @@ def risk_analysis(r, N=252):
r : pandas.Series
daily return series.
N: int
scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
scaler for annualizing information_ratio (day: 250, week: 50, month: 12), at least one of `N` and `freq` should exist
freq: str
analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
"""
def cal_risk_analysis_scaler(freq):
_count, _freq = parse_freq(freq)
_freq_scaler = {
"minute": 240 * 250,
"day": 250,
"week": 50,
"month": 12,
}
return _count * _freq_scaler[_freq]
if N is None and freq is None:
raise ValueError("at least one of `N` and `freq` should exist")
if N is None:
N = cal_risk_analysis_scaler(freq)
mean = r.mean()
std = r.std(ddof=1)
annualized_return = mean * N

View File

@@ -1,291 +0,0 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import json
import copy
import pathlib
import pandas as pd
from ...data import D
from ...utils import get_date_in_file_name
from ...utils import get_pre_trading_date
from ..backtest.order import Order
class BaseExecutor:
"""
# Strategy framework document
class Executor(BaseExecutor):
"""
def execute(self, trade_account, order_list, trade_date):
"""
return the executed result (trade_info) after trading at trade_date.
NOTICE: trade_account will not be modified after executing.
Parameter
---------
trade_account : Account()
order_list : list
[Order()]
trade_date : pd.Timestamp
Return
---------
trade_info : list
[Order(), float, float, float]
"""
raise NotImplementedError("get_execute_result for this model is not implemented.")
def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date):
"""
Save the trade_info to the .csv transaction file in disk
the columns of result file is
['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor']
Parameter
---------
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
user_path: str / pathlib.Path()
the sub folder to save user data
transaction_path : string / pathlib.Path()
"""
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date()))
columns = [
"date",
"stock_id",
"direction",
"amount",
"trade_val",
"trade_cost",
"trade_price",
"factor",
]
data = []
for [order, trade_val, trade_cost, trade_price] in trade_info:
data.append(
[
trade_date,
order.stock_id,
order.direction,
order.amount,
trade_val,
trade_cost,
trade_price,
order.factor,
]
)
df = pd.DataFrame(data, columns=columns)
df.to_csv(transaction_path, index=False)
def load_trade_info_from_executed_file(self, user_path, trade_date):
YYYY, MM, DD = str(trade_date.date()).split("-")
file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date()))
if not file_path.exists():
raise ValueError("File {} not exists!".format(file_path))
filedate = get_date_in_file_name(file_path)
transaction = pd.read_csv(file_path)
trade_info = []
for i in range(len(transaction)):
date = transaction.loc[i]["date"]
if not date == filedate:
continue
# raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate))
order = Order(
stock_id=transaction.loc[i]["stock_id"],
amount=transaction.loc[i]["amount"],
trade_date=transaction.loc[i]["date"],
direction=transaction.loc[i]["direction"],
factor=transaction.loc[i]["factor"],
)
trade_val = transaction.loc[i]["trade_val"]
trade_cost = transaction.loc[i]["trade_cost"]
trade_price = transaction.loc[i]["trade_price"]
trade_info.append([order, trade_val, trade_cost, trade_price])
return trade_info
class SimulatorExecutor(BaseExecutor):
def __init__(self, trade_exchange, verbose=False):
self.trade_exchange = trade_exchange
self.verbose = verbose
self.order_list = []
def execute(self, trade_account, order_list, trade_date):
"""
execute the order list, do the trading wil exchange at date.
Will not modify the trade_account.
Parameter
trade_account : Account()
order_list : list
list or orders
trade_date : pd.Timestamp
:return:
trade_info : list of [Order(), float, float, float]
(order, trade_val, trade_cost, trade_price), trade_info with out factor
"""
account = copy.deepcopy(trade_account)
trade_info = []
for order in order_list:
# check holding thresh is done in strategy
# if order.direction==0: # sell order
# # checking holding thresh limit for sell order
# if trade_account.current.get_stock_count(order.stock_id) < thresh:
# # can not sell this code
# continue
# is order executable
# check order
if self.trade_exchange.check_order(order) is True:
# execute the order
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account)
trade_info.append([order, trade_val, trade_cost, trade_price])
if self.verbose:
if order.direction == Order.SELL: # sell
print(
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
trade_date,
order.stock_id,
trade_price,
order.deal_amount,
trade_val,
)
)
else:
print(
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
trade_date,
order.stock_id,
trade_price,
order.deal_amount,
trade_val,
)
)
else:
if self.verbose:
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id))
# do nothing
pass
return trade_info
def save_score_series(score_series, user_path, trade_date):
"""Save the score_series into a .csv file.
The columns of saved file is
[stock_id, score]
Parameter
---------
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "score" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
score_series.to_csv(file_path)
def load_score_series(user_path, trade_date):
"""Save the score_series into a .csv file.
The columns of saved file is
[stock_id, score]
Parameter
---------
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "score" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"])
return score_series
def save_order_list(order_list, user_path, trade_date):
"""
Save the order list into a json file.
Will calculate the real amount in order according to factors at date.
The format in json file like
{"sell": {"stock_id": amount, ...}
,"buy": {"stock_id": amount, ...}}
:param
order_list: [Order()]
list of Order()
date: pd.Timestamp
the date to save the order list
user_path: str / pathlib.Path()
the sub folder to save user data
"""
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
folder_path = user_path / "trade" / YYYY / MM
if not folder_path.exists():
folder_path.mkdir(parents=True)
sell = {}
buy = {}
for order in order_list:
if order.direction == 0: # sell
sell[order.stock_id] = [order.amount, order.factor]
else:
buy[order.stock_id] = [order.amount, order.factor]
order_dict = {"sell": sell, "buy": buy}
file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date()))
with file_path.open("w") as fp:
json.dump(order_dict, fp)
def load_order_list(user_path, trade_date):
user_path = pathlib.Path(user_path)
YYYY, MM, DD = str(trade_date.date()).split("-")
path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date()))
if not path.exists():
raise ValueError("File {} not exists!".format(path))
# get orders
with path.open("r") as fp:
order_dict = json.load(fp)
order_list = []
for stock_id in order_dict["sell"]:
amount, factor = order_dict["sell"][stock_id]
order = Order(
stock_id=stock_id,
amount=amount,
trade_date=pd.Timestamp(trade_date),
direction=Order.SELL,
factor=factor,
)
order_list.append(order)
for stock_id in order_dict["buy"]:
amount, factor = order_dict["buy"][stock_id]
order = Order(
stock_id=stock_id,
amount=amount,
trade_date=pd.Timestamp(trade_date),
direction=Order.BUY,
factor=factor,
)
order_list.append(order)
return order_list

View File

@@ -3,7 +3,7 @@ import warnings
import numpy as np
import pandas as pd
from ...utils import sample_feature
from ...utils.sample import sample_feature
from ...strategy.base import ModelStrategy
from ..backtest.order import Order
from .order_generator import OrderGenWInteract
@@ -66,7 +66,7 @@ class TopkDropoutStrategy(ModelStrategy):
if trade_exchange:
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index):
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing.
@@ -74,7 +74,7 @@ class TopkDropoutStrategy(ModelStrategy):
# It will use 95% amoutn of your total value by default
return self.risk_degree
def generate_order_list(self, current, **kwargs):
def generate_order_list(self, execute_state):
super(TopkDropoutStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
@@ -120,6 +120,7 @@ class TopkDropoutStrategy(ModelStrategy):
def filter_stock(l):
return l
current = execute_state.get("current")
current_temp = copy.deepcopy(current)
# generate order list for this adjust date
sell_order_list = []
@@ -163,6 +164,7 @@ class TopkDropoutStrategy(ModelStrategy):
# Get the stock list we really want to buy
buy = today[: len(sell) + self.topk - len(last)]
print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
# print("flag", len(sell), len(buy), self.topk, len(last))
for code in current_stock_list:
if not self.trade_exchange.is_stock_tradable(
@@ -175,13 +177,17 @@ class TopkDropoutStrategy(ModelStrategy):
continue
# sell order
sell_amount = current_temp.get_stock_amount(code=code)
factor = self.trade_exchange.get_factor(
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
)
# sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)
sell_order = Order(
stock_id=code,
amount=sell_amount,
start_time=trade_start_time,
end_time=trade_end_time,
direction=Order.SELL, # 0 for sell, 1 for buy
factor=self.trade_exchange.get_factor(code, trade_start_time, trade_end_time),
factor=factor,
)
# is order executable
if self.trade_exchange.check_order(sell_order):
@@ -228,19 +234,36 @@ class WeightStrategyBase(ModelStrategy):
def __init__(
self,
step_bar,
model,
dataset,
start_time=None,
end_time=None,
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
**kwargs,
):
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
self.trade_exchange = trade_exchange
super(WeightStrategyBase, self).__init__(
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
def reset(self, trade_exchange=None, **kwargs):
super(WeightStrategyBase, self).reset(**kwargs)
if trade_exchange:
self.trade_exchange = trade_exchange
def get_risk_degree(self, trade_index=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
Dynamically risk_degree will result in Market timing.
"""
# It will use 95% amoutn of your total value by default
return 0.95
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
"""
Generate target position from score for this date and the current position.The cash is not considered in the position
@@ -256,7 +279,7 @@ class WeightStrategyBase(ModelStrategy):
"""
raise NotImplementedError()
def generate_order_list(self, current, **kwargs):
def generate_order_list(self, execute_state):
"""
Parameters
-----------
@@ -277,7 +300,8 @@ class WeightStrategyBase(ModelStrategy):
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
if pred_score is None:
return []
current_temp = copy.deepcopy(trade_account.current)
current = execute_state.get("current")
current_temp = copy.deepcopy(current)
target_weight_position = self.generate_target_weight_position(
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
)

View File

@@ -3,14 +3,15 @@ import warnings
import numpy as np
import pandas as pd
from ...utils import sample_feature
from ...utils.sample import sample_feature
from ...data.data import D
from ...data.dataset.utils import get_level_index
from ...strategy.base import RuleStrategy, TradingEnhancement
from ...data.dataset.utils import convert_index_format
from ...strategy.base import RuleStrategy, OrderEnhancement
from ..backtest.order import Order
class TWAPStrategy(RuleStrategy, TradingEnhancement):
class TWAPStrategy(RuleStrategy, OrderEnhancement):
def __init__(
self,
step_bar,
@@ -23,7 +24,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
super(TWAPStrategy, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_exchange:
self.trade_exchange = trade_exchange
if trade_order_list:
@@ -31,7 +32,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
for order in self.trade_order_list:
self.trade_amount[(order.stock_id, order.direction)] = order.amount
def generate_order_list(self, **kwargs):
def generate_order_list(self, execute_state):
super(TWAPStrategy, self).step()
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
order_list = []
@@ -66,7 +67,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
return order_list
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
class SBBStrategyBase(RuleStrategy, OrderEnhancement):
"""
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
"""
@@ -87,7 +88,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
super(SBBStrategyBase, self).reset(**kwargs)
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
if trade_exchange:
self.trade_exchange = trade_exchange
if trade_order_list is not None:
@@ -100,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
raise NotImplementedError("pred_price_trend method is not implemented!")
def generate_order_list(self, **kwargs):
def generate_order_list(self, execute_state):
super(SBBStrategyBase, self).step()
if not self.trade_order_list:
return []
@@ -109,7 +110,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
order_list = []
for order in self.trade_order_list:
if self.trade_index % 2 == 1:
_pred_trend = self._pred_price_trend(order.stock_id)
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
else:
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
@@ -127,7 +128,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
self.trade_len - self.trade_index
)
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
@@ -146,6 +147,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
factor=order.factor,
)
order_list.append(_order)
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
else:
_order_amount = None
if _amount_trade_unit is None:
@@ -154,12 +156,12 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
* self.trade_amount[(order.stock_id, order.direction)]
/ (self.trade_len - self.trade_index + 1)
)
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
_order_amount = (
2
* (trade_unit_cnt + self.trade_len - self.trade_index)
(trade_unit_cnt + self.trade_len - self.trade_index)
// (self.trade_len - self.trade_index + 1)
* 2
* _amount_trade_unit
)
if _order_amount:
@@ -197,6 +199,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
factor=order.factor,
)
order_list.append(_order)
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
if self.trade_index % 2 == 1:
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
@@ -226,20 +229,15 @@ class SBBStrategyEMA(SBBStrategyBase):
self.instruments = D.instruments(instruments)
self.freq = freq
def _convert_index_format(self, df):
if get_level_index(df, level="datetime") == 1:
df = df.swaplevel().sort_index()
return df
def _reset_trade_calendar(self, start_time=None, end_time=None):
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
def reset(self, start_time=None, end_time=None, **kwargs):
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
if self.start_time and self.end_time:
fields = ["EMA($close, 10)-EMA($close, 20)"]
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
signal_df = D.features(
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
)
signal_df = self._convert_index_format(signal_df)
signal_df = convert_index_format(signal_df)
signal_df.columns = ["signal"]
self.signal = {}
for stock_id, stock_val in signal_df.groupby(level="instrument"):

View File

@@ -25,7 +25,8 @@ from ..log import get_module_logger
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
from .base import Feature
from .cache import DiskDatasetCache, DiskExpressionCache
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path, sample_calendar
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
from ..utils.sample import sample_calendar
class CalendarProvider(abc.ABC):
@@ -35,7 +36,7 @@ class CalendarProvider(abc.ABC):
"""
@abc.abstractmethod
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
"""Get calendar of certain market in given time range.
Parameters
@@ -46,6 +47,8 @@ class CalendarProvider(abc.ABC):
end of the time range.
freq : str
time frequency, available: year/quarter/month/week/day.
freq_sam : str
sample frequency used for sampling lower-frequency calendar, by default None(raw calendar).
future : bool
whether including future trading day.
@@ -769,7 +772,7 @@ class ClientCalendarProvider(CalendarProvider):
def set_conn(self, conn):
self.conn = conn
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
self.conn.send_request(
request_type="calendar",
@@ -937,8 +940,8 @@ class BaseProvider:
To keep compatible with old qlib provider.
"""
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
return Cal.calendar(start_time, end_time, freq, future=future)
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
return Cal.calendar(start_time, end_time, freq, freq_sam, future=future)
def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None):
if start_time is not None or end_time is not None:

View File

@@ -70,3 +70,27 @@ def fetch_df_by_index(
return df.loc[
pd.IndexSlice[idx_slc],
]
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
"""
Convert the format of df.MultiIndex according to the following rules:
- If `level` is the first level of df.MultiIndex, do nothing
- If `level` is the second level of df.MultiIndex, swap the level of index.
Parameters
----------
df : Union[pd.DataFrame, pd.Series]
raw DataFrame/Series
level : str, optional
the level that will be converted to the first one, by default "datetime"
Returns
-------
Union[pd.DataFrame, pd.Series]
converted DataFrame/Series
"""
if get_level_index(df, level=level) == 1:
df = df.swaplevel().sort_index()
return df

2
qlib/rl/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

104
qlib/rl/env.py Normal file
View File

@@ -0,0 +1,104 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from .interpreter import StateInterpreter, ActionInterpreter
from ..contrib.backtest.executor import BaseExecutor
class BaseRLEnv:
def reset(self, **kwargs):
raise NotImplementedError("reset is not implemented!")
def step(self, action):
"""
step method of rl env
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl policy
"""
raise NotImplementedError("step is not implemented!")
class QlibRLEnv:
"""qlib-based RL env"""
def __init__(
self,
executor: BaseExecutor,
):
"""
Parameters
----------
executor : BaseExecutor
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
"""
self.executor = executor
def reset(self, **kwargs):
self.executor.reset(**kwargs)
class QlibIntRLEnv(QlibRLEnv):
"""(Qlib)-based RL (Env) with (Interpreter)"""
def __init__(
self,
executor: BaseExecutor,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
state_interpret_kwargs: dict = {},
action_interpret_kwargs: dict = {},
):
"""
Parameters
----------
state_interpreter : StateInterpreter
interpretor that interprets the qlib execute result into rl env state.
action_interpreter : ActionInterpreter
interpretor that interprets the rl agent action into qlib order list
state_interpret_kwargs : dict, optional
arguments may be used in `state_interpreter.interpret`, by default {}
such as the following arguments:
- trade exchange : Exchange
Exchange that can provide market info
action_interpret_kwargs: dict, optional
arguments may be used in `action_interpreter.interpret`, by default {}
such as the following arguments:
- trade_order_list : List[Order]
If the strategy is used to split order, it presents the trade order pool.
"""
super(QlibIntRLEnv, self).__init__(executor=executor)
self.state_interpreter = state_interpreter
self.action_interpreter = action_interpreter
self.state_interpret_kwargs = state_interpret_kwargs
self.action_interpret_kwargs = action_interpret_kwargs
def step(self, action):
"""
step method of rl env, it run as following step:
- Use `action_interpreter.interpret` method to interpret the agent action into order list
- Execute the order list with qlib executor, and get the executed result
- Use `state_interpreter.interpret` method to interpret the executed result into env state
Parameters
----------
action :
action from rl policy
Returns
-------
env state to rl rl policy
"""
_interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs)
_execute_result = self.executor.execute(_interpret_action)
_interpret_state = self.state_interpreter.interpret(
execute_result=_execute_result, **self.action_interpret_kwargs
)
return _interpret_state

20
qlib/rl/interpreter.py Normal file
View File

@@ -0,0 +1,20 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
class BaseInterpreter:
@staticmethod
def interpret(**kwargs):
raise NotImplementedError("interpret is not implemented!")
class ActionInterpreter(BaseInterpreter):
@staticmethod
def interpret(action, **kwargs):
raise NotImplementedError("interpret is not implemented!")
class StateInterpreter(BaseInterpreter):
@staticmethod
def interpret(execute_result, **kwargs):
raise NotImplementedError("interpret is not implemented!")

View File

@@ -1,55 +1,160 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
import warnings
import numpy as np
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
from ..utils import get_sample_freq_calendar
from ..model.base import BaseModel
from ..data.dataset import DatasetH
from ..data.dataset.utils import get_level_index
from ..data.dataset.utils import convert_index_format
from ..contrib.backtest.order import Order
from ..contrib.backtest.env import BaseTradeCalendar
"""
1. BaseStrategy 的粒度一定是数据粒度的整数倍
- 关于calendar的合并咋整
- adjust_dates这个东西啥用
- label和freq和strategy的bar分离这个如何决策呢
"""
from ..contrib.backtest.executor import BaseTradeCalendar
from ..rl.interpreter import ActionInterpreter, StateInterpreter
class BaseStrategy(BaseTradeCalendar):
def generate_order_list(self, **kwargs):
"""Base strategy"""
def generate_order_list(self, execute_state):
"""Generate order list in each trading bar"""
raise NotImplementedError("generator_order_list is not implemented!")
class RuleStrategy(BaseStrategy):
"""Trading strategy with rules"""
pass
class ModelStrategy(BaseStrategy):
def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs):
"""Trading Strategy by using Model to make predictions"""
def __init__(
self,
step_bar: str,
model: BaseModel,
dataset: DatasetH,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
**kwargs,
):
"""
Parameters
----------
model : BaseModel
the model used in when making predictions
dataset : DatasetH
provide test data for model
kwargs : dict
arguments that will be passed into `reset` method
"""
self.model = model
self.dataset = dataset
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
def _convert_index_format(self, df):
if get_level_index(df, level="datetime") == 1:
df = df.swaplevel().sort_index()
return df
def _update_model(self):
"""update pred score"""
"""
Update model in each bar when using online data as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
raise NotImplementedError("_update_model is not implemented!")
class TradingEnhancement:
def reset(self, trade_order_list=None):
class RLStrategy(BaseStrategy):
"""RL-based Strategy"""
def __init__(
self,
step_bar: str,
policy,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
**kwargs,
):
"""
Parameters
----------
policy :
RL policy for generate action
"""
super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
self.policy = policy
class RLIntStrategy(RLStrategy):
"""(RL)-based (Strategy) with (Int)erpreter"""
def __init__(
self,
step_bar: str,
policy,
state_interpreter: StateInterpreter,
action_interpreter: ActionInterpreter,
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
state_interpret_kwargs: dict = {},
action_interpret_kwargs: dict = {},
**kwargs,
):
"""
Parameters
----------
state_interpreter : StateInterpreter
interpretor that interprets the qlib execute result into rl env state.
action_interpreter : ActionInterpreter
interpretor that interprets the rl agent action into qlib order list
start_time : Union[str, pd.Timestamp], optional
start time of trading, by default None
end_time : Union[str, pd.Timestamp], optional
end time of trading, by default None
state_interpret_kwargs : dict, optional
arguments may be used in `state_interpreter.interpret`, by default {}
such as the following arguments:
- trade exchange : Exchange
Exchange that can provide market info
action_interpret_kwargs: dict, optional
arguments may be used in `action_interpreter.interpret`, by default {}
such as the following arguments:
- trade_order_list : List[Order]
If the strategy is used to split order, it presents the trade order pool.
"""
super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs)
self.policy = policy
self.action_interpreter = action_interpreter
self.state_interpreter = state_interpreter
self.state_interpret_kwargs = state_interpret_kwargs
self.action_interpret_kwargs = action_interpret_kwargs
def generate_order_list(self, execute_state):
super(RLStrategy, self).step()
_interpret_state = self.state_interpretor.interpret(
execute_result=execute_state, **self.action_interpret_kwargs
)
_policy_action = self.policy.step(_interpret_state)
_order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs)
return _order_list
class OrderEnhancement:
"""
Order enhancement for strategy
- If the strategy is used to split orders, the enhancement should be inherited
- If the strategy is used for portfolio management, the enhancement can be ignored
"""
def reset(self, trade_order_list: List[Order] = None):
"""reset trade orders for split strategy
Parameters
----------
trade_order_list for split strategy: List[Order], optional
trading orders , by default None
"""
if trade_order_list is not None:
self.trade_order_list = trade_order_list

View File

@@ -800,217 +800,3 @@ def fname_to_code(fname: str):
if fname.startswith(prefix):
fname = fname.lstrip(prefix)
return fname
########################## Sample ############################
def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
"""
freq_raw : "min" or "day"
"""
freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
if freq_sam.endswith(("minute", "min")):
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if 9 <= hour <= 11:
minute_index = (11 - hour) * 60 + 30 - minute + 120
elif 13 <= hour <= 15:
minute_index = (15 - hour) * 60 - minute
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
return 15 - (minute_index + 59) // 60, (120 - minute_index) % 60
elif 120 <= minute_index < 240:
return 11 - (minute_index - 120 + 29) // 60, (240 - minute_index + 30) % 60
else:
raise ValueError("calendar minute_index error")
sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6])
if not freq_raw.endswith(("minute", "min")):
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
else:
raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6])
if raw_minutes > sam_minutes:
raise ValueError("raw freq must be higher than sample freq")
_calendar_minute = np.unique(
list(
map(
lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59),
calendar_raw,
)
)
)
return _calendar_minute
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw)))
if freq_sam.endswith(("day", "d")):
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days]
elif freq_sam.endswith(("week", "w")):
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0]
return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks]
elif freq_sam.endswith(("month", "m")):
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0]
return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months]
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def parse_freq(freq):
freq = freq.lower()
search_obj = re.search("^([0-9]*)([a-z]+)", freq)
if search_obj is None:
raise ValueError("freq format is not supported")
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
_freq = search_obj.group(2)
_freq_format_dict = {
"month": "month",
"mon": "month",
"week": "week",
"w": "week",
"day": "day",
"d": "day",
"minute": "minute",
"min": "minute",
}
try:
_freq = _freq_format_dict.get(_freq)
except KeyError:
raise ValueError(
"freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min"
)
return _count, _freq
def sample_calendar(calendar_raw, freq_raw, freq_sam):
"""
freq_raw : "min" or "day"
"""
raw_count, freq_raw = parse_freq(freq_raw)
sam_count, freq_sam = parse_freq(freq_sam)
if not len(calendar_raw):
return calendar_raw
if freq_sam == "minute":
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
minute_index = (hour - 9) * 60 + minute - 30
elif 13 <= hour < 15:
minute_index = (hour - 13) * 60 + minute + 120
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
elif 120 <= minute_index < 240:
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
else:
raise ValueError("calendar minute_index error")
if req_raw != "minute":
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
else:
if raw_count > sam_count:
raise ValueError("raw freq must be higher than sample freq")
_calendar_minute = np.unique(
list(
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
)
)
if calendar_raw[0] > _calendar_minute[0]:
_calendar_minute[0] = calendar_raw[0]
return _calendar_minute
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
if freq_sam == "day":
return _calendar_day[::sam_count]
elif freq_sam == "week":
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
return _calendar_week[::sam_count]
elif freq_sam == "month":
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
return _calendar_month[::sam_count]
else:
raise ValueError("sample freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
_, norm_freq = parse_freq(freq)
from ..data.data import Cal
try:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
freq, freq_sam = freq, None
except ValueError:
freq_sam = freq
if norm_freq in ["month", "week", "day"]:
try:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
freq = "day"
except ValueError:
raise
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
elif norm_freq == "minute":
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
selector_datetime = slice(start_time, end_time)
fields = fields if fields else slice(None)
from ..data.dataset.utils import get_level_index
datetime_level = get_level_index(feature, level="datetime") == 0
if isinstance(feature, pd.Series):
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
elif isinstance(feature, pd.DataFrame):
feature = (
feature.loc[selector_datetime, fields]
if datetime_level
else feature.loc[(slice(None), selector_datetime), fields]
)
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):
if callable(method):
method_func = method
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
elif isinstance(method, str):
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
if callable(method):
method_func = method
return method_func(feature, **method_kwargs)
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature

300
qlib/utils/sample.py Normal file
View File

@@ -0,0 +1,300 @@
import re
import numpy as np
import pandas as pd
from typing import Tuple, List, Union, Optional, Callable
def parse_freq(freq: str) -> Tuple[int, str]:
"""
Parse freq into a unified format
Parameters
----------
freq : str
Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
Returns
-------
freq: Tuple[int, str]
Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
Example:
.. code-block::
print(parse_freq("day"))
(1, "day" )
print(parse_freq("2mon"))
(2, "month")
print(parse_freq("10w"))
(10, "week")
"""
freq = freq.lower()
match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
if match_obj is None:
raise ValueError(
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
)
_count = int(match_obj.group(1) if match_obj.group(1) else "1")
_freq = match_obj.group(2)
_freq_format_dict = {
"month": "month",
"mon": "month",
"week": "week",
"w": "week",
"day": "day",
"d": "day",
"minute": "minute",
"min": "minute",
}
return _count, _freq_format_dict[_freq]
def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
"""
Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam
Parameters
----------
calendar_raw : np.ndarray
The calendar with frequency freq_raw
freq_raw : str
Frequency of the raw calendar
freq_sam : str
Sample frequency
Returns
-------
np.ndarray
The calendar with frequency freq_sam
"""
raw_count, freq_raw = parse_freq(freq_raw)
sam_count, freq_sam = parse_freq(freq_sam)
if not len(calendar_raw):
return calendar_raw
if freq_sam == "minute":
def cal_next_sam_minute(x, sam_minutes):
hour = x.hour
minute = x.minute
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
minute_index = (hour - 9) * 60 + minute - 30
elif 13 <= hour < 15:
minute_index = (hour - 13) * 60 + minute + 120
else:
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
minute_index = minute_index // sam_minutes * sam_minutes
if 0 <= minute_index < 120:
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
elif 120 <= minute_index < 240:
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
else:
raise ValueError("calendar minute_index error")
if freq_raw != "minute":
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
else:
if raw_count > sam_count:
raise ValueError("raw freq must be higher than sampling freq")
_calendar_minute = np.unique(
list(
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
)
)
if calendar_raw[0] > _calendar_minute[0]:
_calendar_minute[0] = calendar_raw[0]
return _calendar_minute
else:
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
if freq_sam == "day":
return _calendar_day[::sam_count]
elif freq_sam == "week":
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
return _calendar_week[::sam_count]
elif freq_sam == "month":
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
return _calendar_month[::sam_count]
else:
raise ValueError("sampling freq must be xmin, xd, xw, xm")
def get_sample_freq_calendar(
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
freq: str = "day",
future: bool = False,
) -> Tuple[np.ndarray, str, Optional[str]]:
"""
Get the calendar with frequency freq.
- If the calendar with the raw frequency freq exists, return it directly
- Else, sample from a higher frequency calendar automatically
Parameters
----------
start_time : Union[str, pd.Timestamp], optional
start time of calendar, by default None
end_time : Union[str, pd.Timestamp], optional
end time of calendar, by default None
freq : str, optional
freq of calendar, by default "day"
future : bool, optional
whether including future trading day.
Returns
-------
Tuple[np.ndarray, str, Optional[str]]
- the first value is the calendar
- the second value is the raw freq of calendar
- the third value is the sampling freq of calendar, it's None if the raw frequency freq exists.
"""
_, norm_freq = parse_freq(freq)
from ..data.data import Cal
try:
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, future=future)
freq, freq_sam = freq, None
except ValueError:
freq_sam = freq
if norm_freq in ["month", "week", "day"]:
try:
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
)
freq = "day"
except ValueError:
_calendar = Cal.calendar(
start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future
)
freq = "min"
elif norm_freq == "minute":
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future)
freq = "min"
else:
raise ValueError(f"freq {freq} is not supported")
return _calendar, freq, freq_sam
def sample_feature(
feature: Union[pd.DataFrame, pd.Series],
start_time: Union[str, pd.Timestamp] = None,
end_time: Union[str, pd.Timestamp] = None,
fields: Union[str, List[str]] = None,
method: Union[str, Callable] = "last",
method_kwargs: dict = {},
):
"""
Sample value from pandas DataFrame or Series for each stock
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time]
Example:
.. code-block::
print(feature)
$close $volume
instrument datetime
SH600000 2010-01-04 86.778313 16162960.0
2010-01-05 87.433578 28117442.0
2010-01-06 85.713585 23632884.0
2010-01-07 83.788803 20813402.0
2010-01-08 84.730675 16044853.0
SH600655 2010-01-04 2699.567383 158193.328125
2010-01-08 2612.359619 77501.406250
2010-01-11 2712.982422 160852.390625
2010-01-12 2788.688232 164587.937500
2010-01-13 2790.604004 145460.453125
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
$close $volume
instrument
SH600000 87.433578 28117442.0
SH600655 2699.567383 158193.328125
- Else, the `feature` should have Index[datetime], just apply the `method` to `feature` directly
Example:
.. code-block::
print(feature)
$close $volume
datetime
2010-01-04 86.778313 16162960.0
2010-01-05 87.433578 28117442.0
2010-01-06 85.713585 23632884.0
2010-01-07 83.788803 20813402.0
2010-01-08 84.730675 16044853.0
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
$close 87.433578
$volume 28117442.0
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last"))
87.433578
Parameters
----------
feature : Union[pd.DataFrame, pd.Series]
Raw feature to be sampled
start_time : Union[str, pd.Timestamp], optional
start sampling time, by default None
end_time : Union[str, pd.Timestamp], optional
end sampling time, by default None
fields : Union[str, List[str]], optional
column names, it's ignored when sample pd.Series data, by default None(all columns)
method : Union[str, Callable], optional
sample method, apply method function to each stock series data, by default "last"
- If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
- If `feature` has MultiIndex[instrument, datetime], method must be a member of pandas.groupby when it's type is str.or callable function.
method_kwargs : dict, optional
arguments of method, by default {}
Returns
-------
The Sampled DataFrame/Series/Value
"""
selector_datetime = slice(start_time, end_time)
if fields is None:
fields = slice(None)
from ..data.dataset.utils import get_level_index
datetime_level = get_level_index(feature, level="datetime") == 0
if isinstance(feature, pd.Series):
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
elif isinstance(feature, pd.DataFrame):
feature = (
feature.loc[selector_datetime, fields]
if datetime_level
else feature.loc[(slice(None), selector_datetime), fields]
)
if feature.empty:
return None
if isinstance(feature.index, pd.MultiIndex):
if callable(method):
method_func = method
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
elif isinstance(method, str):
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
else:
if callable(method):
method_func = method
return method_func(feature, **method_kwargs)
elif isinstance(method, str):
return getattr(feature, method)(**method_kwargs)
return feature

View File

@@ -14,7 +14,8 @@ from ..data.dataset import DatasetH
from ..data.dataset.handler import DataHandlerLP
from ..utils import init_instance_by_config, get_module_by_module_path
from ..log import get_module_logger
from ..utils import flatten_dict, parse_freq
from ..utils import flatten_dict
from ..utils.sample import parse_freq
from ..strategy.base import BaseStrategy
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
@@ -315,16 +316,6 @@ class PortAnaRecord(RecordTemp):
ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"]))
return ret_freq
def _cal_risk_analysis_scaler(self, freq):
_count, _freq = parse_freq(freq)
_freq_scaler = {
"minute": 240 * 250,
"day": 250,
"week": 50,
"month": 12,
}
return _count * _freq_scaler[_freq]
def generate(self, **kwargs):
# custom strategy and get backtest
report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
@@ -343,12 +334,11 @@ class PortAnaRecord(RecordTemp):
else:
report_normal, _ = report_dict.get(self.risk_analysis_freq)
analysis = dict()
risk_analysis_scaler = self._cal_risk_analysis_scaler(self.risk_analysis_freq)
analysis["excess_return_without_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"], risk_analysis_scaler
report_normal["return"] - report_normal["bench"], self.risk_analysis_freq
)
analysis["excess_return_with_cost"] = risk_analysis(
report_normal["return"] - report_normal["bench"] - report_normal["cost"], risk_analysis_scaler
report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq
)
analysis_df = pd.concat(analysis) # type: pd.DataFrame
# log metrics