mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fix some comments and add docstring
This commit is contained in:
@@ -91,13 +91,13 @@ if __name__ == "__main__":
|
||||
},
|
||||
},
|
||||
"env": {
|
||||
"class": "SplitEnv",
|
||||
"module_path": "qlib.contrib.backtest.env",
|
||||
"class": "SplitExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
"step_bar": "week",
|
||||
"sub_env": {
|
||||
"class": "SimulatorEnv",
|
||||
"module_path": "qlib.contrib.backtest.env",
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.contrib.backtest.executor",
|
||||
"kwargs": {
|
||||
"step_bar": "day",
|
||||
"verbose": True,
|
||||
@@ -118,14 +118,17 @@ if __name__ == "__main__":
|
||||
"backtest": {
|
||||
"start_time": trade_start_time,
|
||||
"end_time": trade_end_time,
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .order import Order
|
||||
from .position import Position
|
||||
|
||||
from .exchange import Exchange
|
||||
from .report import Report
|
||||
from .executor import BaseExecutor
|
||||
from .backtest import backtest as backtest_func
|
||||
|
||||
import copy
|
||||
import numpy as np
|
||||
import inspect
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...utils import init_instance_by_config
|
||||
from ...log import get_module_logger
|
||||
from ...config import C
|
||||
@@ -90,21 +88,6 @@ def get_exchange(
|
||||
return init_instance_by_config(exchange, accept_types=Exchange)
|
||||
|
||||
|
||||
def init_env_instance_by_config(env):
|
||||
if isinstance(env, dict):
|
||||
env_config = copy.copy(env)
|
||||
if "kwargs" in env_config:
|
||||
env_kwargs = copy.copy(env_config["kwargs"])
|
||||
if "sub_env" in env_kwargs:
|
||||
env_kwargs["sub_env"] = init_env_instance_by_config(env_kwargs["sub_env"])
|
||||
if "sub_strategy" in env_kwargs:
|
||||
env_kwargs["sub_strategy"] = init_instance_by_config(env_kwargs["sub_strategy"])
|
||||
env_config["kwargs"] = env_kwargs
|
||||
return init_instance_by_config(env_config)
|
||||
else:
|
||||
return env
|
||||
|
||||
|
||||
def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
if "trade_exchange" in inspect.getfullargspec(root_instance.__class__).args:
|
||||
if force:
|
||||
@@ -118,13 +101,11 @@ def setup_exchange(root_instance, trade_exchange=None, force=False):
|
||||
setup_exchange(root_instance.sub_strategy, trade_exchange)
|
||||
|
||||
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, **kwargs):
|
||||
trade_strategy = init_instance_by_config(strategy)
|
||||
trade_env = init_env_instance_by_config(env)
|
||||
def backtest(start_time, end_time, strategy, env, benchmark="SH000905", account=1e9, exchange_kwargs={}):
|
||||
trade_strategy = init_instance_by_config(strategy, accept_types=BaseStrategy)
|
||||
trade_env = init_instance_by_config(env, accept_types=BaseExecutor)
|
||||
|
||||
spec = inspect.getfullargspec(get_exchange)
|
||||
exchange_args = {k: v for k, v in kwargs.items() if k in spec.args}
|
||||
trade_exchange = get_exchange(**exchange_args)
|
||||
trade_exchange = get_exchange(**exchange_kwargs)
|
||||
|
||||
setup_exchange(trade_env, trade_exchange)
|
||||
setup_exchange(trade_strategy, trade_exchange)
|
||||
|
||||
@@ -3,13 +3,14 @@
|
||||
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
import pandas as pd
|
||||
|
||||
from .position import Position
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
from ...data import D
|
||||
from ...utils import parse_freq, sample_feature
|
||||
from ...utils.sample import parse_freq, sample_feature
|
||||
|
||||
|
||||
"""
|
||||
@@ -110,6 +111,8 @@ class Account:
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(f"reser error, attribute {k} is not found!")
|
||||
|
||||
def get_positions(self):
|
||||
return self.positions
|
||||
|
||||
@@ -1,10 +1,6 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from .account import Account
|
||||
|
||||
|
||||
@@ -14,9 +10,9 @@ def backtest(start_time, end_time, trade_strategy, trade_env, benchmark, account
|
||||
trade_env.reset(start_time=start_time, end_time=end_time, trade_account=trade_account)
|
||||
trade_strategy.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
trade_state = trade_env.get_init_state()
|
||||
_execute_state = trade_env.get_init_state()
|
||||
while not trade_env.finished():
|
||||
_order_list = trade_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = trade_env.execute(_order_list)
|
||||
_order_list = trade_strategy.generate_order_list(_execute_state)
|
||||
_execute_state = trade_env.execute(_order_list)
|
||||
|
||||
return trade_env.get_report()
|
||||
|
||||
@@ -11,7 +11,7 @@ import pandas as pd
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...config import C, REG_CN
|
||||
from ...utils import sample_feature
|
||||
from ...utils.sample import sample_feature
|
||||
from ...log import get_module_logger
|
||||
from .order import Order
|
||||
|
||||
|
||||
@@ -1,19 +1,34 @@
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
import warnings
|
||||
import pathlib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
from ...data.data import Cal
|
||||
from ...utils import get_sample_freq_calendar, parse_freq
|
||||
from .position import Position
|
||||
from ...strategy.base import BaseStrategy
|
||||
from ...utils import init_instance_by_config
|
||||
from ...utils.sample import get_sample_freq_calendar, parse_freq
|
||||
from .report import Report
|
||||
from .order import Order
|
||||
from .account import Account
|
||||
from .exchange import Exchange
|
||||
|
||||
|
||||
class BaseTradeCalendar:
|
||||
def __init__(self, step_bar, start_time=None, end_time=None, **kwargs):
|
||||
def __init__(
|
||||
self, step_bar: str, start_time: Union[str, pd.Timestamp] = None, end_time: Union[str, pd.Timestamp] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
step_bar : str
|
||||
frequency of each trading step bar
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time of trading, by default None
|
||||
If `start_time` is None, it must be reset before trading.
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of trading, by default None
|
||||
If `end_time` is None, it must be reset before trading.
|
||||
"""
|
||||
|
||||
self.step_bar = step_bar
|
||||
self.reset(start_time=start_time, end_time=end_time)
|
||||
|
||||
@@ -27,10 +42,9 @@ class BaseTradeCalendar:
|
||||
if self.start_time and self.end_time:
|
||||
_calendar, freq, freq_sam = get_sample_freq_calendar(freq=self.step_bar)
|
||||
self.calendar = _calendar
|
||||
_start_time, _end_time, _start_index, _end_index = Cal.locate_index(
|
||||
_, _, _start_index, _end_index = Cal.locate_index(
|
||||
self.start_time, self.end_time, freq=freq, freq_sam=freq_sam
|
||||
)
|
||||
_trade_calendar = self.calendar[_start_index : _end_index + 1]
|
||||
self.start_index = _start_index
|
||||
self.end_index = _end_index
|
||||
self.trade_len = _end_index - _start_index + 1
|
||||
@@ -45,6 +59,8 @@ class BaseTradeCalendar:
|
||||
for k, v in kwargs.items():
|
||||
if hasattr(self, k):
|
||||
setattr(self, k, v)
|
||||
else:
|
||||
warnings.warn(f"reser error, attribute {k} is not found!")
|
||||
|
||||
def _get_calendar_time(self, trade_index=1, shift=0):
|
||||
trade_index = trade_index - shift
|
||||
@@ -55,34 +71,43 @@ class BaseTradeCalendar:
|
||||
return self.trade_index >= self.trade_len - 1
|
||||
|
||||
def step(self):
|
||||
if self.finished():
|
||||
raise RuntimeError(f"this env has completed its task, please reset it if you want to call it!")
|
||||
self.trade_index = self.trade_index + 1
|
||||
|
||||
|
||||
class BaseEnv(BaseTradeCalendar):
|
||||
"""
|
||||
# Strategy framework document
|
||||
|
||||
class Env(BaseEnv):
|
||||
"""
|
||||
class BaseExecutor(BaseTradeCalendar):
|
||||
"""Base executor for trading"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
step_bar: str,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
super(BaseEnv, self).__init__(
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account, optional
|
||||
trade account for trading, by default None
|
||||
If `trade_account` is None, it must be reset before trading
|
||||
generate_report : bool, optional
|
||||
whether to generate report, by default False
|
||||
verbose : bool, optional
|
||||
whether to print log, by default False
|
||||
"""
|
||||
super(BaseExecutor, self).__init__(
|
||||
step_bar=step_bar, start_time=start_time, end_time=end_time, trade_account=trade_account, **kwargs
|
||||
)
|
||||
self.generate_report = generate_report
|
||||
self.verbose = verbose
|
||||
|
||||
def reset(self, trade_account=None, **kwargs):
|
||||
super(BaseEnv, self).reset(**kwargs)
|
||||
super(BaseExecutor, self).reset(**kwargs)
|
||||
if trade_account:
|
||||
self.trade_account = trade_account
|
||||
self.trade_account.reset(freq=self.step_bar, report=Report(), positions={})
|
||||
@@ -101,23 +126,31 @@ class BaseEnv(BaseTradeCalendar):
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
|
||||
|
||||
class SplitEnv(BaseEnv):
|
||||
class SplitExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
sub_env,
|
||||
sub_strategy,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
trade_account=None,
|
||||
trade_exchange=None,
|
||||
generate_report=False,
|
||||
verbose=False,
|
||||
step_bar: str,
|
||||
sub_env: Union[BaseExecutor, dict],
|
||||
sub_strategy: Union[BaseStrategy, dict],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
trade_account: Account = None,
|
||||
trade_exchange: Exchange = None,
|
||||
generate_report: bool = False,
|
||||
verbose: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
self.sub_env = sub_env
|
||||
self.sub_strategy = sub_strategy
|
||||
super(SplitEnv, self).__init__(
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
sub_env : BaseExecutor
|
||||
trading env in each trading bar.
|
||||
sub_strategy : BaseStrategy
|
||||
trading strategy in each trading bar
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info
|
||||
"""
|
||||
super(SplitExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
@@ -127,28 +160,26 @@ class SplitEnv(BaseEnv):
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
self.sub_env = init_instance_by_config(sub_env, accept_types=BaseExecutor)
|
||||
self.sub_strategy = init_instance_by_config(sub_strategy, accept_types=BaseStrategy)
|
||||
|
||||
def reset(self, trade_account=None, trade_exchange=None, **kwargs):
|
||||
super(SplitEnv, self).reset(trade_account=trade_account, **kwargs)
|
||||
|
||||
super(SplitExecutor, self).reset(trade_account=trade_account, **kwargs)
|
||||
if trade_account:
|
||||
self.sub_env.reset(trade_account=copy.copy(trade_account))
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
# if self.track:
|
||||
# yield action
|
||||
# episode_reward = 0
|
||||
super(SplitEnv, self).step()
|
||||
def execute(self, order_list):
|
||||
super(SplitExecutor, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
self.sub_env.reset(start_time=trade_start_time, end_time=trade_end_time)
|
||||
self.sub_strategy.reset(start_time=trade_start_time, end_time=trade_end_time, trade_order_list=order_list)
|
||||
trade_state = self.sub_env.get_init_state()
|
||||
_execute_state = self.sub_env.get_init_state()
|
||||
while not self.sub_env.finished():
|
||||
_order_list = self.sub_strategy.generate_order_list(**trade_state)
|
||||
trade_state, trade_info = self.sub_env.execute(order_list=_order_list)
|
||||
_order_list = self.sub_strategy.generate_order_list(_execute_state)
|
||||
_execute_state = self.sub_env.execute(order_list=_order_list)
|
||||
|
||||
self.trade_account.update_bar_end(
|
||||
trade_start_time=trade_start_time,
|
||||
@@ -156,9 +187,8 @@ class SplitEnv(BaseEnv):
|
||||
trade_exchange=self.trade_exchange,
|
||||
update_report=self.generate_report,
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {}
|
||||
return _obs, _info
|
||||
_execute_state = {"current": self.trade_account.current}
|
||||
return _execute_state
|
||||
|
||||
def get_report(self):
|
||||
sub_env_report_dict = self.sub_env.get_report()
|
||||
@@ -167,12 +197,10 @@ class SplitEnv(BaseEnv):
|
||||
_positions = self.trade_account.get_positions()
|
||||
_count, _freq = parse_freq(self.step_bar)
|
||||
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
|
||||
return sub_env_report_dict
|
||||
else:
|
||||
return sub_env_report_dict
|
||||
return sub_env_report_dict
|
||||
|
||||
|
||||
class SimulatorEnv(BaseEnv):
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
@@ -184,7 +212,13 @@ class SimulatorEnv(BaseEnv):
|
||||
verbose=False,
|
||||
**kwargs,
|
||||
):
|
||||
super(SimulatorEnv, self).__init__(
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
trade_exchange : Exchange
|
||||
exchange that provides market info
|
||||
"""
|
||||
super(SimulatorExecutor, self).__init__(
|
||||
step_bar=step_bar,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
@@ -196,17 +230,12 @@ class SimulatorEnv(BaseEnv):
|
||||
)
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(SimulatorEnv, self).reset(**kwargs)
|
||||
super(SimulatorExecutor, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def execute(self, order_list, **kwargs):
|
||||
"""
|
||||
Return: obs, done, info
|
||||
"""
|
||||
if self.finished():
|
||||
raise StopIteration(f"this env has completed its task, please reset it if you want to call it!")
|
||||
super(SimulatorEnv, self).step()
|
||||
def execute(self, order_list):
|
||||
super(SimulatorExecutor, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
trade_info = []
|
||||
for order in order_list:
|
||||
@@ -219,21 +248,25 @@ class SimulatorEnv(BaseEnv):
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.amount,
|
||||
order.deal_amount,
|
||||
order.factor,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, deal_amount {}, factor {}, value {:.2f}.".format(
|
||||
trade_start_time,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.amount,
|
||||
order.deal_amount,
|
||||
order.factor,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
@@ -249,9 +282,8 @@ class SimulatorEnv(BaseEnv):
|
||||
trade_exchange=self.trade_exchange,
|
||||
update_report=self.generate_report,
|
||||
)
|
||||
_obs = {"current": self.trade_account.current}
|
||||
_info = {"trade_info": trade_info}
|
||||
return _obs, _info
|
||||
_execute_state = {"current": self.trade_account.current, "trade_info": trade_info}
|
||||
return _execute_state
|
||||
|
||||
def get_report(self):
|
||||
if self.generate_report:
|
||||
@@ -1,16 +0,0 @@
|
||||
class BaseInterpreter:
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class ActionInterpreter:
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
return action
|
||||
|
||||
|
||||
class StateInterpreter:
|
||||
@staticmethod
|
||||
def interpret(state, **kwargs):
|
||||
return state
|
||||
@@ -10,6 +10,7 @@ import warnings
|
||||
from ..log import get_module_logger
|
||||
from .backtest import get_exchange, backtest as backtest_func
|
||||
from ..utils import get_date_range
|
||||
from ..utils.sample import parse_freq
|
||||
|
||||
from ..data import D
|
||||
from ..config import C
|
||||
@@ -19,7 +20,7 @@ from ..data.dataset.utils import get_level_index
|
||||
logger = get_module_logger("Evaluate")
|
||||
|
||||
|
||||
def risk_analysis(r, N=252):
|
||||
def risk_analysis(r, N: int = None, freq: str = None):
|
||||
"""Risk Analysis
|
||||
|
||||
Parameters
|
||||
@@ -27,8 +28,26 @@ def risk_analysis(r, N=252):
|
||||
r : pandas.Series
|
||||
daily return series.
|
||||
N: int
|
||||
scaler for annualizing information_ratio (day: 250, week: 50, month: 12).
|
||||
scaler for annualizing information_ratio (day: 250, week: 50, month: 12), at least one of `N` and `freq` should exist
|
||||
freq: str
|
||||
analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
|
||||
"""
|
||||
|
||||
def cal_risk_analysis_scaler(freq):
|
||||
_count, _freq = parse_freq(freq)
|
||||
_freq_scaler = {
|
||||
"minute": 240 * 250,
|
||||
"day": 250,
|
||||
"week": 50,
|
||||
"month": 12,
|
||||
}
|
||||
return _count * _freq_scaler[_freq]
|
||||
|
||||
if N is None and freq is None:
|
||||
raise ValueError("at least one of `N` and `freq` should exist")
|
||||
if N is None:
|
||||
N = cal_risk_analysis_scaler(freq)
|
||||
|
||||
mean = r.mean()
|
||||
std = r.std(ddof=1)
|
||||
annualized_return = mean * N
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import re
|
||||
import json
|
||||
import copy
|
||||
import pathlib
|
||||
import pandas as pd
|
||||
from ...data import D
|
||||
from ...utils import get_date_in_file_name
|
||||
from ...utils import get_pre_trading_date
|
||||
from ..backtest.order import Order
|
||||
|
||||
|
||||
class BaseExecutor:
|
||||
"""
|
||||
# Strategy framework document
|
||||
|
||||
class Executor(BaseExecutor):
|
||||
"""
|
||||
|
||||
def execute(self, trade_account, order_list, trade_date):
|
||||
"""
|
||||
return the executed result (trade_info) after trading at trade_date.
|
||||
NOTICE: trade_account will not be modified after executing.
|
||||
Parameter
|
||||
---------
|
||||
trade_account : Account()
|
||||
order_list : list
|
||||
[Order()]
|
||||
trade_date : pd.Timestamp
|
||||
Return
|
||||
---------
|
||||
trade_info : list
|
||||
[Order(), float, float, float]
|
||||
"""
|
||||
raise NotImplementedError("get_execute_result for this model is not implemented.")
|
||||
|
||||
def save_executed_file_from_trade_info(self, trade_info, user_path, trade_date):
|
||||
"""
|
||||
Save the trade_info to the .csv transaction file in disk
|
||||
the columns of result file is
|
||||
['date', 'stock_id', 'direction', 'trade_val', 'trade_cost', 'trade_price', 'factor']
|
||||
Parameter
|
||||
---------
|
||||
trade_info : list of [Order(), float, float, float]
|
||||
(order, trade_val, trade_cost, trade_price), trade_info with out factor
|
||||
user_path: str / pathlib.Path()
|
||||
the sub folder to save user data
|
||||
|
||||
transaction_path : string / pathlib.Path()
|
||||
"""
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
folder_path = pathlib.Path(user_path) / "trade" / YYYY / MM
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
transaction_path = folder_path / "transaction_{}.csv".format(str(trade_date.date()))
|
||||
columns = [
|
||||
"date",
|
||||
"stock_id",
|
||||
"direction",
|
||||
"amount",
|
||||
"trade_val",
|
||||
"trade_cost",
|
||||
"trade_price",
|
||||
"factor",
|
||||
]
|
||||
data = []
|
||||
for [order, trade_val, trade_cost, trade_price] in trade_info:
|
||||
data.append(
|
||||
[
|
||||
trade_date,
|
||||
order.stock_id,
|
||||
order.direction,
|
||||
order.amount,
|
||||
trade_val,
|
||||
trade_cost,
|
||||
trade_price,
|
||||
order.factor,
|
||||
]
|
||||
)
|
||||
df = pd.DataFrame(data, columns=columns)
|
||||
df.to_csv(transaction_path, index=False)
|
||||
|
||||
def load_trade_info_from_executed_file(self, user_path, trade_date):
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
file_path = pathlib.Path(user_path) / "trade" / YYYY / MM / "transaction_{}.csv".format(str(trade_date.date()))
|
||||
if not file_path.exists():
|
||||
raise ValueError("File {} not exists!".format(file_path))
|
||||
|
||||
filedate = get_date_in_file_name(file_path)
|
||||
transaction = pd.read_csv(file_path)
|
||||
trade_info = []
|
||||
for i in range(len(transaction)):
|
||||
date = transaction.loc[i]["date"]
|
||||
if not date == filedate:
|
||||
continue
|
||||
# raise ValueError("date in transaction file {} not equal to it's file date{}".format(date, filedate))
|
||||
order = Order(
|
||||
stock_id=transaction.loc[i]["stock_id"],
|
||||
amount=transaction.loc[i]["amount"],
|
||||
trade_date=transaction.loc[i]["date"],
|
||||
direction=transaction.loc[i]["direction"],
|
||||
factor=transaction.loc[i]["factor"],
|
||||
)
|
||||
trade_val = transaction.loc[i]["trade_val"]
|
||||
trade_cost = transaction.loc[i]["trade_cost"]
|
||||
trade_price = transaction.loc[i]["trade_price"]
|
||||
trade_info.append([order, trade_val, trade_cost, trade_price])
|
||||
return trade_info
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
def __init__(self, trade_exchange, verbose=False):
|
||||
self.trade_exchange = trade_exchange
|
||||
self.verbose = verbose
|
||||
self.order_list = []
|
||||
|
||||
def execute(self, trade_account, order_list, trade_date):
|
||||
"""
|
||||
execute the order list, do the trading wil exchange at date.
|
||||
Will not modify the trade_account.
|
||||
Parameter
|
||||
trade_account : Account()
|
||||
order_list : list
|
||||
list or orders
|
||||
trade_date : pd.Timestamp
|
||||
:return:
|
||||
trade_info : list of [Order(), float, float, float]
|
||||
(order, trade_val, trade_cost, trade_price), trade_info with out factor
|
||||
"""
|
||||
account = copy.deepcopy(trade_account)
|
||||
trade_info = []
|
||||
|
||||
for order in order_list:
|
||||
# check holding thresh is done in strategy
|
||||
# if order.direction==0: # sell order
|
||||
# # checking holding thresh limit for sell order
|
||||
# if trade_account.current.get_stock_count(order.stock_id) < thresh:
|
||||
# # can not sell this code
|
||||
# continue
|
||||
# is order executable
|
||||
# check order
|
||||
if self.trade_exchange.check_order(order) is True:
|
||||
# execute the order
|
||||
trade_val, trade_cost, trade_price = self.trade_exchange.deal_order(order, trade_account=account)
|
||||
trade_info.append([order, trade_val, trade_cost, trade_price])
|
||||
if self.verbose:
|
||||
if order.direction == Order.SELL: # sell
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: sell {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
trade_date,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.deal_amount,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"[I {:%Y-%m-%d}]: buy {}, price {:.2f}, amount {}, value {:.2f}.".format(
|
||||
trade_date,
|
||||
order.stock_id,
|
||||
trade_price,
|
||||
order.deal_amount,
|
||||
trade_val,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
if self.verbose:
|
||||
print("[W {:%Y-%m-%d}]: {} wrong.".format(trade_date, order.stock_id))
|
||||
# do nothing
|
||||
pass
|
||||
return trade_info
|
||||
|
||||
|
||||
def save_score_series(score_series, user_path, trade_date):
|
||||
"""Save the score_series into a .csv file.
|
||||
The columns of saved file is
|
||||
[stock_id, score]
|
||||
|
||||
Parameter
|
||||
---------
|
||||
order_list: [Order()]
|
||||
list of Order()
|
||||
date: pd.Timestamp
|
||||
the date to save the order list
|
||||
user_path: str / pathlib.Path()
|
||||
the sub folder to save user data
|
||||
"""
|
||||
user_path = pathlib.Path(user_path)
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
folder_path = user_path / "score" / YYYY / MM
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
|
||||
score_series.to_csv(file_path)
|
||||
|
||||
|
||||
def load_score_series(user_path, trade_date):
|
||||
"""Save the score_series into a .csv file.
|
||||
The columns of saved file is
|
||||
[stock_id, score]
|
||||
|
||||
Parameter
|
||||
---------
|
||||
order_list: [Order()]
|
||||
list of Order()
|
||||
date: pd.Timestamp
|
||||
the date to save the order list
|
||||
user_path: str / pathlib.Path()
|
||||
the sub folder to save user data
|
||||
"""
|
||||
user_path = pathlib.Path(user_path)
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
folder_path = user_path / "score" / YYYY / MM
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
file_path = folder_path / "score_{}.csv".format(str(trade_date.date()))
|
||||
score_series = pd.read_csv(file_path, index_col=0, header=None, names=["instrument", "score"])
|
||||
return score_series
|
||||
|
||||
|
||||
def save_order_list(order_list, user_path, trade_date):
|
||||
"""
|
||||
Save the order list into a json file.
|
||||
Will calculate the real amount in order according to factors at date.
|
||||
|
||||
The format in json file like
|
||||
{"sell": {"stock_id": amount, ...}
|
||||
,"buy": {"stock_id": amount, ...}}
|
||||
|
||||
:param
|
||||
order_list: [Order()]
|
||||
list of Order()
|
||||
date: pd.Timestamp
|
||||
the date to save the order list
|
||||
user_path: str / pathlib.Path()
|
||||
the sub folder to save user data
|
||||
"""
|
||||
user_path = pathlib.Path(user_path)
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
folder_path = user_path / "trade" / YYYY / MM
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True)
|
||||
sell = {}
|
||||
buy = {}
|
||||
for order in order_list:
|
||||
if order.direction == 0: # sell
|
||||
sell[order.stock_id] = [order.amount, order.factor]
|
||||
else:
|
||||
buy[order.stock_id] = [order.amount, order.factor]
|
||||
order_dict = {"sell": sell, "buy": buy}
|
||||
file_path = folder_path / "orderlist_{}.json".format(str(trade_date.date()))
|
||||
with file_path.open("w") as fp:
|
||||
json.dump(order_dict, fp)
|
||||
|
||||
|
||||
def load_order_list(user_path, trade_date):
|
||||
user_path = pathlib.Path(user_path)
|
||||
YYYY, MM, DD = str(trade_date.date()).split("-")
|
||||
path = user_path / "trade" / YYYY / MM / "orderlist_{}.json".format(str(trade_date.date()))
|
||||
if not path.exists():
|
||||
raise ValueError("File {} not exists!".format(path))
|
||||
# get orders
|
||||
with path.open("r") as fp:
|
||||
order_dict = json.load(fp)
|
||||
order_list = []
|
||||
for stock_id in order_dict["sell"]:
|
||||
amount, factor = order_dict["sell"][stock_id]
|
||||
order = Order(
|
||||
stock_id=stock_id,
|
||||
amount=amount,
|
||||
trade_date=pd.Timestamp(trade_date),
|
||||
direction=Order.SELL,
|
||||
factor=factor,
|
||||
)
|
||||
order_list.append(order)
|
||||
for stock_id in order_dict["buy"]:
|
||||
amount, factor = order_dict["buy"][stock_id]
|
||||
order = Order(
|
||||
stock_id=stock_id,
|
||||
amount=amount,
|
||||
trade_date=pd.Timestamp(trade_date),
|
||||
direction=Order.BUY,
|
||||
factor=factor,
|
||||
)
|
||||
order_list.append(order)
|
||||
return order_list
|
||||
@@ -3,7 +3,7 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils import sample_feature
|
||||
from ...utils.sample import sample_feature
|
||||
from ...strategy.base import ModelStrategy
|
||||
from ..backtest.order import Order
|
||||
from .order_generator import OrderGenWInteract
|
||||
@@ -66,7 +66,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index):
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
@@ -74,7 +74,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return self.risk_degree
|
||||
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
def generate_order_list(self, execute_state):
|
||||
super(TopkDropoutStrategy, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
pred_start_time, pred_end_time = self._get_calendar_time(self.trade_index, shift=1)
|
||||
@@ -120,6 +120,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
def filter_stock(l):
|
||||
return l
|
||||
|
||||
current = execute_state.get("current")
|
||||
current_temp = copy.deepcopy(current)
|
||||
# generate order list for this adjust date
|
||||
sell_order_list = []
|
||||
@@ -163,6 +164,7 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
|
||||
# Get the stock list we really want to buy
|
||||
buy = today[: len(sell) + self.topk - len(last)]
|
||||
print("INTRANEL BAR", len(sell), len(sell) + self.topk - len(last), len(last))
|
||||
# print("flag", len(sell), len(buy), self.topk, len(last))
|
||||
for code in current_stock_list:
|
||||
if not self.trade_exchange.is_stock_tradable(
|
||||
@@ -175,13 +177,17 @@ class TopkDropoutStrategy(ModelStrategy):
|
||||
continue
|
||||
# sell order
|
||||
sell_amount = current_temp.get_stock_amount(code=code)
|
||||
factor = self.trade_exchange.get_factor(
|
||||
stock_id=code, start_time=trade_start_time, end_time=trade_end_time
|
||||
)
|
||||
# sell_amount = self.trade_exchange.round_amount_by_trade_unit(sell_amount, factor)
|
||||
sell_order = Order(
|
||||
stock_id=code,
|
||||
amount=sell_amount,
|
||||
start_time=trade_start_time,
|
||||
end_time=trade_end_time,
|
||||
direction=Order.SELL, # 0 for sell, 1 for buy
|
||||
factor=self.trade_exchange.get_factor(code, trade_start_time, trade_end_time),
|
||||
factor=factor,
|
||||
)
|
||||
# is order executable
|
||||
if self.trade_exchange.check_order(sell_order):
|
||||
@@ -228,19 +234,36 @@ class WeightStrategyBase(ModelStrategy):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
model,
|
||||
dataset,
|
||||
start_time=None,
|
||||
end_time=None,
|
||||
order_generator_cls_or_obj=OrderGenWInteract,
|
||||
trade_exchange=None,
|
||||
**kwargs,
|
||||
):
|
||||
super(WeightStrategyBase, self).__init__(step_bar, start_time, end_time)
|
||||
self.trade_exchange = trade_exchange
|
||||
super(WeightStrategyBase, self).__init__(
|
||||
step_bar, model, dataset, start_time, end_time, trade_exchange=trade_exchange, **kwargs
|
||||
)
|
||||
|
||||
if isinstance(order_generator_cls_or_obj, type):
|
||||
self.order_generator = order_generator_cls_or_obj()
|
||||
else:
|
||||
self.order_generator = order_generator_cls_or_obj
|
||||
|
||||
def reset(self, trade_exchange=None, **kwargs):
|
||||
super(WeightStrategyBase, self).reset(**kwargs)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
|
||||
def get_risk_degree(self, trade_index=None):
|
||||
"""get_risk_degree
|
||||
Return the proportion of your total value you will used in investment.
|
||||
Dynamically risk_degree will result in Market timing.
|
||||
"""
|
||||
# It will use 95% amoutn of your total value by default
|
||||
return 0.95
|
||||
|
||||
def generate_target_weight_position(self, score, current, trade_start_time, trade_end_time):
|
||||
"""
|
||||
Generate target position from score for this date and the current position.The cash is not considered in the position
|
||||
@@ -256,7 +279,7 @@ class WeightStrategyBase(ModelStrategy):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def generate_order_list(self, current, **kwargs):
|
||||
def generate_order_list(self, execute_state):
|
||||
"""
|
||||
Parameters
|
||||
-----------
|
||||
@@ -277,7 +300,8 @@ class WeightStrategyBase(ModelStrategy):
|
||||
pred_score = sample_feature(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
|
||||
if pred_score is None:
|
||||
return []
|
||||
current_temp = copy.deepcopy(trade_account.current)
|
||||
current = execute_state.get("current")
|
||||
current_temp = copy.deepcopy(current)
|
||||
target_weight_position = self.generate_target_weight_position(
|
||||
score=pred_score, current=current_temp, trade_start_time=trade_start_time, trade_end_time=trade_end_time
|
||||
)
|
||||
|
||||
@@ -3,14 +3,15 @@ import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from ...utils import sample_feature
|
||||
|
||||
from ...utils.sample import sample_feature
|
||||
from ...data.data import D
|
||||
from ...data.dataset.utils import get_level_index
|
||||
from ...strategy.base import RuleStrategy, TradingEnhancement
|
||||
from ...data.dataset.utils import convert_index_format
|
||||
from ...strategy.base import RuleStrategy, OrderEnhancement
|
||||
from ..backtest.order import Order
|
||||
|
||||
|
||||
class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
class TWAPStrategy(RuleStrategy, OrderEnhancement):
|
||||
def __init__(
|
||||
self,
|
||||
step_bar,
|
||||
@@ -23,7 +24,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
|
||||
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
|
||||
super(TWAPStrategy, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
if trade_order_list:
|
||||
@@ -31,7 +32,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
for order in self.trade_order_list:
|
||||
self.trade_amount[(order.stock_id, order.direction)] = order.amount
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
def generate_order_list(self, execute_state):
|
||||
super(TWAPStrategy, self).step()
|
||||
trade_start_time, trade_end_time = self._get_calendar_time(self.trade_index)
|
||||
order_list = []
|
||||
@@ -66,7 +67,7 @@ class TWAPStrategy(RuleStrategy, TradingEnhancement):
|
||||
return order_list
|
||||
|
||||
|
||||
class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
class SBBStrategyBase(RuleStrategy, OrderEnhancement):
|
||||
"""
|
||||
(S)elect the (B)etter one among every two adjacent trading (B)ars to sell or buy.
|
||||
"""
|
||||
@@ -87,7 +88,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
|
||||
def reset(self, trade_order_list=None, trade_exchange=None, **kwargs):
|
||||
super(SBBStrategyBase, self).reset(**kwargs)
|
||||
TradingEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
OrderEnhancement.reset(self, trade_order_list=trade_order_list)
|
||||
if trade_exchange:
|
||||
self.trade_exchange = trade_exchange
|
||||
if trade_order_list is not None:
|
||||
@@ -100,7 +101,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
def _pred_price_trend(self, stock_id, pred_start_time=None, pred_end_time=None):
|
||||
raise NotImplementedError("pred_price_trend method is not implemented!")
|
||||
|
||||
def generate_order_list(self, **kwargs):
|
||||
def generate_order_list(self, execute_state):
|
||||
super(SBBStrategyBase, self).step()
|
||||
if not self.trade_order_list:
|
||||
return []
|
||||
@@ -109,7 +110,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
order_list = []
|
||||
for order in self.trade_order_list:
|
||||
if self.trade_index % 2 == 1:
|
||||
_pred_trend = self._pred_price_trend(order.stock_id)
|
||||
_pred_trend = self._pred_price_trend(order.stock_id, pred_start_time, pred_end_time)
|
||||
else:
|
||||
_pred_trend = self.trade_trend[(order.stock_id, order.direction)]
|
||||
|
||||
@@ -127,7 +128,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
_order_amount = self.trade_amount[(order.stock_id, order.direction)] / (
|
||||
self.trade_len - self.trade_index
|
||||
)
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index - 1)
|
||||
@@ -146,6 +147,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
|
||||
else:
|
||||
_order_amount = None
|
||||
if _amount_trade_unit is None:
|
||||
@@ -154,12 +156,12 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
* self.trade_amount[(order.stock_id, order.direction)]
|
||||
/ (self.trade_len - self.trade_index + 1)
|
||||
)
|
||||
if self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
elif self.trade_amount[(order.stock_id, order.direction)] >= _amount_trade_unit:
|
||||
trade_unit_cnt = int(self.trade_amount[(order.stock_id, order.direction)] // _amount_trade_unit)
|
||||
_order_amount = (
|
||||
2
|
||||
* (trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
(trade_unit_cnt + self.trade_len - self.trade_index)
|
||||
// (self.trade_len - self.trade_index + 1)
|
||||
* 2
|
||||
* _amount_trade_unit
|
||||
)
|
||||
if _order_amount:
|
||||
@@ -197,6 +199,7 @@ class SBBStrategyBase(RuleStrategy, TradingEnhancement):
|
||||
factor=order.factor,
|
||||
)
|
||||
order_list.append(_order)
|
||||
# print("DEBUG AMOUNT", _order_amount, self.trade_amount[(order.stock_id, order.direction)], _amount_trade_unit)
|
||||
if self.trade_index % 2 == 1:
|
||||
self.trade_trend[(order.stock_id, order.direction)] = _pred_trend
|
||||
|
||||
@@ -226,20 +229,15 @@ class SBBStrategyEMA(SBBStrategyBase):
|
||||
self.instruments = D.instruments(instruments)
|
||||
self.freq = freq
|
||||
|
||||
def _convert_index_format(self, df):
|
||||
if get_level_index(df, level="datetime") == 1:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
def _reset_trade_calendar(self, start_time=None, end_time=None):
|
||||
super(SBBStrategyEMA, self)._reset_trade_calendar(start_time=start_time, end_time=end_time)
|
||||
def reset(self, start_time=None, end_time=None, **kwargs):
|
||||
super(SBBStrategyEMA, self).reset(start_time=start_time, end_time=end_time, **kwargs)
|
||||
if self.start_time and self.end_time:
|
||||
fields = ["EMA($close, 10)-EMA($close, 20)"]
|
||||
signal_start_time, _ = self._get_calendar_time(trade_index=self.trade_index, shift=1)
|
||||
signal_df = D.features(
|
||||
self.instruments, fields, start_time=signal_start_time, end_time=self.end_time, freq=self.freq
|
||||
)
|
||||
signal_df = self._convert_index_format(signal_df)
|
||||
signal_df = convert_index_format(signal_df)
|
||||
signal_df.columns = ["signal"]
|
||||
self.signal = {}
|
||||
for stock_id, stock_val in signal_df.groupby(level="instrument"):
|
||||
|
||||
@@ -25,7 +25,8 @@ from ..log import get_module_logger
|
||||
from ..utils import parse_field, read_bin, hash_args, normalize_cache_fields, code_to_fname
|
||||
from .base import Feature
|
||||
from .cache import DiskDatasetCache, DiskExpressionCache
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path, sample_calendar
|
||||
from ..utils import Wrapper, init_instance_by_config, register_wrapper, get_module_by_module_path
|
||||
from ..utils.sample import sample_calendar
|
||||
|
||||
|
||||
class CalendarProvider(abc.ABC):
|
||||
@@ -35,7 +36,7 @@ class CalendarProvider(abc.ABC):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
"""Get calendar of certain market in given time range.
|
||||
|
||||
Parameters
|
||||
@@ -46,6 +47,8 @@ class CalendarProvider(abc.ABC):
|
||||
end of the time range.
|
||||
freq : str
|
||||
time frequency, available: year/quarter/month/week/day.
|
||||
freq_sam : str
|
||||
sample frequency used for sampling lower-frequency calendar, by default None(raw calendar).
|
||||
future : bool
|
||||
whether including future trading day.
|
||||
|
||||
@@ -769,7 +772,7 @@ class ClientCalendarProvider(CalendarProvider):
|
||||
def set_conn(self, conn):
|
||||
self.conn = conn
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
|
||||
self.conn.send_request(
|
||||
request_type="calendar",
|
||||
@@ -937,8 +940,8 @@ class BaseProvider:
|
||||
To keep compatible with old qlib provider.
|
||||
"""
|
||||
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", future=False):
|
||||
return Cal.calendar(start_time, end_time, freq, future=future)
|
||||
def calendar(self, start_time=None, end_time=None, freq="day", freq_sam=None, future=False):
|
||||
return Cal.calendar(start_time, end_time, freq, freq_sam, future=future)
|
||||
|
||||
def instruments(self, market="all", filter_pipe=None, start_time=None, end_time=None):
|
||||
if start_time is not None or end_time is not None:
|
||||
|
||||
@@ -70,3 +70,27 @@ def fetch_df_by_index(
|
||||
return df.loc[
|
||||
pd.IndexSlice[idx_slc],
|
||||
]
|
||||
|
||||
|
||||
def convert_index_format(df: Union[pd.DataFrame, pd.Series], level: str = "datetime") -> Union[pd.DataFrame, pd.Series]:
|
||||
"""
|
||||
Convert the format of df.MultiIndex according to the following rules:
|
||||
- If `level` is the first level of df.MultiIndex, do nothing
|
||||
- If `level` is the second level of df.MultiIndex, swap the level of index.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
df : Union[pd.DataFrame, pd.Series]
|
||||
raw DataFrame/Series
|
||||
level : str, optional
|
||||
the level that will be converted to the first one, by default "datetime"
|
||||
|
||||
Returns
|
||||
-------
|
||||
Union[pd.DataFrame, pd.Series]
|
||||
converted DataFrame/Series
|
||||
"""
|
||||
|
||||
if get_level_index(df, level=level) == 1:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
2
qlib/rl/__init__.py
Normal file
2
qlib/rl/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
104
qlib/rl/env.py
Normal file
104
qlib/rl/env.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
from .interpreter import StateInterpreter, ActionInterpreter
|
||||
|
||||
from ..contrib.backtest.executor import BaseExecutor
|
||||
|
||||
|
||||
class BaseRLEnv:
|
||||
def reset(self, **kwargs):
|
||||
raise NotImplementedError("reset is not implemented!")
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl policy
|
||||
"""
|
||||
raise NotImplementedError("step is not implemented!")
|
||||
|
||||
|
||||
class QlibRLEnv:
|
||||
"""qlib-based RL env"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
executor : BaseExecutor
|
||||
qlib multi-level/single-level executor, which can be regarded as gamecore in RL
|
||||
"""
|
||||
self.executor = executor
|
||||
|
||||
def reset(self, **kwargs):
|
||||
self.executor.reset(**kwargs)
|
||||
|
||||
|
||||
class QlibIntRLEnv(QlibRLEnv):
|
||||
"""(Qlib)-based RL (Env) with (Interpreter)"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: BaseExecutor,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
state_interpret_kwargs: dict = {},
|
||||
action_interpret_kwargs: dict = {},
|
||||
):
|
||||
"""
|
||||
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : StateInterpreter
|
||||
interpretor that interprets the qlib execute result into rl env state.
|
||||
action_interpreter : ActionInterpreter
|
||||
interpretor that interprets the rl agent action into qlib order list
|
||||
state_interpret_kwargs : dict, optional
|
||||
arguments may be used in `state_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade exchange : Exchange
|
||||
Exchange that can provide market info
|
||||
action_interpret_kwargs: dict, optional
|
||||
arguments may be used in `action_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade_order_list : List[Order]
|
||||
If the strategy is used to split order, it presents the trade order pool.
|
||||
"""
|
||||
super(QlibIntRLEnv, self).__init__(executor=executor)
|
||||
self.state_interpreter = state_interpreter
|
||||
self.action_interpreter = action_interpreter
|
||||
self.state_interpret_kwargs = state_interpret_kwargs
|
||||
self.action_interpret_kwargs = action_interpret_kwargs
|
||||
|
||||
def step(self, action):
|
||||
"""
|
||||
step method of rl env, it run as following step:
|
||||
- Use `action_interpreter.interpret` method to interpret the agent action into order list
|
||||
- Execute the order list with qlib executor, and get the executed result
|
||||
- Use `state_interpreter.interpret` method to interpret the executed result into env state
|
||||
|
||||
Parameters
|
||||
----------
|
||||
action :
|
||||
action from rl policy
|
||||
|
||||
Returns
|
||||
-------
|
||||
env state to rl rl policy
|
||||
"""
|
||||
_interpret_action = self.action_interpreter.interpret(action=action, **self.state_interpret_kwargs)
|
||||
_execute_result = self.executor.execute(_interpret_action)
|
||||
_interpret_state = self.state_interpreter.interpret(
|
||||
execute_result=_execute_result, **self.action_interpret_kwargs
|
||||
)
|
||||
return _interpret_state
|
||||
20
qlib/rl/interpreter.py
Normal file
20
qlib/rl/interpreter.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
class BaseInterpreter:
|
||||
@staticmethod
|
||||
def interpret(**kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class ActionInterpreter(BaseInterpreter):
|
||||
@staticmethod
|
||||
def interpret(action, **kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
|
||||
|
||||
class StateInterpreter(BaseInterpreter):
|
||||
@staticmethod
|
||||
def interpret(execute_result, **kwargs):
|
||||
raise NotImplementedError("interpret is not implemented!")
|
||||
@@ -1,55 +1,160 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
|
||||
import copy
|
||||
import warnings
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
|
||||
from ..utils import get_sample_freq_calendar
|
||||
from ..model.base import BaseModel
|
||||
from ..data.dataset import DatasetH
|
||||
from ..data.dataset.utils import get_level_index
|
||||
from ..data.dataset.utils import convert_index_format
|
||||
from ..contrib.backtest.order import Order
|
||||
from ..contrib.backtest.env import BaseTradeCalendar
|
||||
|
||||
"""
|
||||
1. BaseStrategy 的粒度一定是数据粒度的整数倍
|
||||
- 关于calendar的合并咋整
|
||||
- adjust_dates这个东西啥用
|
||||
- label和freq和strategy的bar分离,这个如何决策呢
|
||||
"""
|
||||
from ..contrib.backtest.executor import BaseTradeCalendar
|
||||
from ..rl.interpreter import ActionInterpreter, StateInterpreter
|
||||
|
||||
|
||||
class BaseStrategy(BaseTradeCalendar):
|
||||
def generate_order_list(self, **kwargs):
|
||||
"""Base strategy"""
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
"""Generate order list in each trading bar"""
|
||||
raise NotImplementedError("generator_order_list is not implemented!")
|
||||
|
||||
|
||||
class RuleStrategy(BaseStrategy):
|
||||
"""Trading strategy with rules"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ModelStrategy(BaseStrategy):
|
||||
def __init__(self, step_bar, model, dataset: DatasetH, start_time=None, end_time=None, **kwargs):
|
||||
"""Trading Strategy by using Model to make predictions"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
model: BaseModel,
|
||||
dataset: DatasetH,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
model : BaseModel
|
||||
the model used in when making predictions
|
||||
dataset : DatasetH
|
||||
provide test data for model
|
||||
kwargs : dict
|
||||
arguments that will be passed into `reset` method
|
||||
"""
|
||||
self.model = model
|
||||
self.dataset = dataset
|
||||
self.pred_scores = self._convert_index_format(self.model.predict(dataset))
|
||||
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
|
||||
# pred_score_dates = self.pred_scores.index.get_level_values(level="datetime")
|
||||
super(ModelStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
|
||||
def _convert_index_format(self, df):
|
||||
if get_level_index(df, level="datetime") == 1:
|
||||
df = df.swaplevel().sort_index()
|
||||
return df
|
||||
|
||||
def _update_model(self):
|
||||
"""update pred score"""
|
||||
"""
|
||||
Update model in each bar when using online data as the following steps:
|
||||
- update dataset with online data, the dataset should support online update
|
||||
- make the latest prediction scores of the new bar
|
||||
- update the pred score into the latest prediction
|
||||
"""
|
||||
raise NotImplementedError("_update_model is not implemented!")
|
||||
|
||||
|
||||
class TradingEnhancement:
|
||||
def reset(self, trade_order_list=None):
|
||||
class RLStrategy(BaseStrategy):
|
||||
"""RL-based Strategy"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
policy,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
policy :
|
||||
RL policy for generate action
|
||||
"""
|
||||
super(RLStrategy, self).__init__(step_bar, start_time, end_time, **kwargs)
|
||||
self.policy = policy
|
||||
|
||||
|
||||
class RLIntStrategy(RLStrategy):
|
||||
"""(RL)-based (Strategy) with (Int)erpreter"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
step_bar: str,
|
||||
policy,
|
||||
state_interpreter: StateInterpreter,
|
||||
action_interpreter: ActionInterpreter,
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
state_interpret_kwargs: dict = {},
|
||||
action_interpret_kwargs: dict = {},
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
state_interpreter : StateInterpreter
|
||||
interpretor that interprets the qlib execute result into rl env state.
|
||||
action_interpreter : ActionInterpreter
|
||||
interpretor that interprets the rl agent action into qlib order list
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time of trading, by default None
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of trading, by default None
|
||||
state_interpret_kwargs : dict, optional
|
||||
arguments may be used in `state_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade exchange : Exchange
|
||||
Exchange that can provide market info
|
||||
action_interpret_kwargs: dict, optional
|
||||
arguments may be used in `action_interpreter.interpret`, by default {}
|
||||
such as the following arguments:
|
||||
- trade_order_list : List[Order]
|
||||
If the strategy is used to split order, it presents the trade order pool.
|
||||
"""
|
||||
super(RLIntStrategy, self).__init__(step_bar, policy, start_time, end_time, **kwargs)
|
||||
|
||||
self.policy = policy
|
||||
self.action_interpreter = action_interpreter
|
||||
self.state_interpreter = state_interpreter
|
||||
self.state_interpret_kwargs = state_interpret_kwargs
|
||||
self.action_interpret_kwargs = action_interpret_kwargs
|
||||
|
||||
def generate_order_list(self, execute_state):
|
||||
super(RLStrategy, self).step()
|
||||
_interpret_state = self.state_interpretor.interpret(
|
||||
execute_result=execute_state, **self.action_interpret_kwargs
|
||||
)
|
||||
_policy_action = self.policy.step(_interpret_state)
|
||||
_order_list = self.action_interpreter.interpret(action=_policy_action, **self.state_interpret_kwargs)
|
||||
return _order_list
|
||||
|
||||
|
||||
class OrderEnhancement:
|
||||
"""
|
||||
Order enhancement for strategy
|
||||
- If the strategy is used to split orders, the enhancement should be inherited
|
||||
- If the strategy is used for portfolio management, the enhancement can be ignored
|
||||
"""
|
||||
|
||||
def reset(self, trade_order_list: List[Order] = None):
|
||||
"""reset trade orders for split strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_order_list for split strategy: List[Order], optional
|
||||
trading orders , by default None
|
||||
"""
|
||||
if trade_order_list is not None:
|
||||
self.trade_order_list = trade_order_list
|
||||
|
||||
@@ -800,217 +800,3 @@ def fname_to_code(fname: str):
|
||||
if fname.startswith(prefix):
|
||||
fname = fname.lstrip(prefix)
|
||||
return fname
|
||||
|
||||
|
||||
########################## Sample ############################
|
||||
def sample_calendar_bac(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
freq_raw : "min" or "day"
|
||||
"""
|
||||
freq_raw = "1" + freq_raw if re.match("^[0-9]", freq_raw) is None else freq_raw
|
||||
freq_sam = "1" + freq_sam if re.match("^[0-9]", freq_sam) is None else freq_sam
|
||||
|
||||
if freq_sam.endswith(("minute", "min")):
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if 9 <= hour <= 11:
|
||||
minute_index = (11 - hour) * 60 + 30 - minute + 120
|
||||
elif 13 <= hour <= 15:
|
||||
minute_index = (15 - hour) * 60 - minute
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
return 15 - (minute_index + 59) // 60, (120 - minute_index) % 60
|
||||
elif 120 <= minute_index < 240:
|
||||
return 11 - (minute_index - 120 + 29) // 60, (240 - minute_index + 30) % 60
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
|
||||
sam_minutes = int(freq_sam[:-3]) if freq_sam.endswith("min") else int(freq_sam[:-6])
|
||||
|
||||
if not freq_raw.endswith(("minute", "min")):
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
raw_minutes = int(freq_raw[:-3]) if freq_raw.endswith("min") else int(freq_raw[:-6])
|
||||
if raw_minutes > sam_minutes:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(
|
||||
lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_minutes), 59),
|
||||
calendar_raw,
|
||||
)
|
||||
)
|
||||
)
|
||||
return _calendar_minute
|
||||
else:
|
||||
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 23, 59, 59), calendar_raw)))
|
||||
if freq_sam.endswith(("day", "d")):
|
||||
sam_days = int(freq_sam[:-1]) if freq_sam.endswith("d") else int(freq_sam[:-3])
|
||||
return _calendar_day[(len(_calendar_day) + sam_days - 1) % sam_days :: sam_days]
|
||||
|
||||
elif freq_sam.endswith(("week", "w")):
|
||||
sam_weeks = int(freq_sam[:-1]) if freq_sam.endswith("w") else int(freq_sam[:-4])
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week[::-1], to_begin=1)[::-1] > 0]
|
||||
return _calendar_week[(len(_calendar_week) + sam_weeks - 1) % sam_weeks :: sam_weeks]
|
||||
|
||||
elif freq_sam.endswith(("month", "m")):
|
||||
sam_months = int(freq_sam[:-1]) if freq_sam.endswith("m") else int(freq_sam[:-5])
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month[::-1], to_begin=1)[::-1] > 0]
|
||||
return _calendar_month[(len(_calendar_month) + sam_months - 1) % sam_months :: sam_months]
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
def parse_freq(freq):
|
||||
freq = freq.lower()
|
||||
search_obj = re.search("^([0-9]*)([a-z]+)", freq)
|
||||
if search_obj is None:
|
||||
raise ValueError("freq format is not supported")
|
||||
_count = int(search_obj.group(1) if search_obj.group(1) else "1")
|
||||
_freq = search_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": "month",
|
||||
"mon": "month",
|
||||
"week": "week",
|
||||
"w": "week",
|
||||
"day": "day",
|
||||
"d": "day",
|
||||
"minute": "minute",
|
||||
"min": "minute",
|
||||
}
|
||||
try:
|
||||
_freq = _freq_format_dict.get(_freq)
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the supported freq includes (x)month/m, (x)day/d, (x)minute/min"
|
||||
)
|
||||
return _count, _freq
|
||||
|
||||
|
||||
def sample_calendar(calendar_raw, freq_raw, freq_sam):
|
||||
"""
|
||||
freq_raw : "min" or "day"
|
||||
"""
|
||||
raw_count, freq_raw = parse_freq(freq_raw)
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
|
||||
minute_index = (hour - 9) * 60 + minute - 30
|
||||
elif 13 <= hour < 15:
|
||||
minute_index = (hour - 13) * 60 + minute + 120
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
|
||||
elif 120 <= minute_index < 240:
|
||||
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
|
||||
if req_raw != "minute":
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sample freq")
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
|
||||
)
|
||||
)
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
return _calendar_minute
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == "day":
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam == "week":
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam == "month":
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_count]
|
||||
else:
|
||||
raise ValueError("sample freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
def get_sample_freq_calendar(start_time=None, end_time=None, freq="day", **kwargs):
|
||||
_, norm_freq = parse_freq(freq)
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
try:
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, **kwargs)
|
||||
freq, freq_sam = freq, None
|
||||
except ValueError:
|
||||
freq_sam = freq
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, **kwargs)
|
||||
freq = "day"
|
||||
except ValueError:
|
||||
raise
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
elif norm_freq == "minute":
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, **kwargs)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def sample_feature(feature, start_time=None, end_time=None, fields=None, method="last", method_kwargs={}):
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
fields = fields if fields else slice(None)
|
||||
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
feature = (
|
||||
feature.loc[selector_datetime, fields]
|
||||
if datetime_level
|
||||
else feature.loc[(slice(None), selector_datetime), fields]
|
||||
)
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return method_func(feature, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
|
||||
return feature
|
||||
|
||||
300
qlib/utils/sample.py
Normal file
300
qlib/utils/sample.py
Normal file
@@ -0,0 +1,300 @@
|
||||
import re
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from typing import Tuple, List, Union, Optional, Callable
|
||||
|
||||
|
||||
def parse_freq(freq: str) -> Tuple[int, str]:
|
||||
"""
|
||||
Parse freq into a unified format
|
||||
|
||||
Parameters
|
||||
----------
|
||||
freq : str
|
||||
Raw freq, supported freq should match the re '^([0-9]*)(month|mon|week|w|day|d|minute|min)$'
|
||||
|
||||
Returns
|
||||
-------
|
||||
freq: Tuple[int, str]
|
||||
Unified freq, including freq count and unified freq unit. The freq unit should be '[month|week|day|minute]'.
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
print(parse_freq("day"))
|
||||
(1, "day" )
|
||||
print(parse_freq("2mon"))
|
||||
(2, "month")
|
||||
print(parse_freq("10w"))
|
||||
(10, "week")
|
||||
|
||||
"""
|
||||
freq = freq.lower()
|
||||
match_obj = re.match("^([0-9]*)(month|mon|week|w|day|d|minute|min)$", freq)
|
||||
if match_obj is None:
|
||||
raise ValueError(
|
||||
"freq format is not supported, the freq should be like (n)month/mon, (n)week/w, (n)day/d, (n)minute/min"
|
||||
)
|
||||
_count = int(match_obj.group(1) if match_obj.group(1) else "1")
|
||||
_freq = match_obj.group(2)
|
||||
_freq_format_dict = {
|
||||
"month": "month",
|
||||
"mon": "month",
|
||||
"week": "week",
|
||||
"w": "week",
|
||||
"day": "day",
|
||||
"d": "day",
|
||||
"minute": "minute",
|
||||
"min": "minute",
|
||||
}
|
||||
return _count, _freq_format_dict[_freq]
|
||||
|
||||
|
||||
def sample_calendar(calendar_raw: np.ndarray, freq_raw: str, freq_sam: str) -> np.ndarray:
|
||||
"""
|
||||
Sample the calendar with frequency freq_raw into the calendar with frequency freq_sam
|
||||
|
||||
Parameters
|
||||
----------
|
||||
calendar_raw : np.ndarray
|
||||
The calendar with frequency freq_raw
|
||||
freq_raw : str
|
||||
Frequency of the raw calendar
|
||||
freq_sam : str
|
||||
Sample frequency
|
||||
|
||||
Returns
|
||||
-------
|
||||
np.ndarray
|
||||
The calendar with frequency freq_sam
|
||||
"""
|
||||
raw_count, freq_raw = parse_freq(freq_raw)
|
||||
sam_count, freq_sam = parse_freq(freq_sam)
|
||||
if not len(calendar_raw):
|
||||
return calendar_raw
|
||||
if freq_sam == "minute":
|
||||
|
||||
def cal_next_sam_minute(x, sam_minutes):
|
||||
hour = x.hour
|
||||
minute = x.minute
|
||||
if (hour == 9 and minute >= 30) or (9 < hour < 11) or (hour == 11 and minute < 30):
|
||||
minute_index = (hour - 9) * 60 + minute - 30
|
||||
elif 13 <= hour < 15:
|
||||
minute_index = (hour - 13) * 60 + minute + 120
|
||||
else:
|
||||
raise ValueError("calendar hour must be in [9, 11] or [13, 15]")
|
||||
|
||||
minute_index = minute_index // sam_minutes * sam_minutes
|
||||
|
||||
if 0 <= minute_index < 120:
|
||||
return 9 + (minute_index + 30) // 60, (minute_index + 30) % 60
|
||||
elif 120 <= minute_index < 240:
|
||||
return 13 + (minute_index - 120) // 60, (minute_index - 120) % 60
|
||||
else:
|
||||
raise ValueError("calendar minute_index error")
|
||||
|
||||
if freq_raw != "minute":
|
||||
raise ValueError("when sampling minute calendar, freq of raw calendar must be minute or min")
|
||||
else:
|
||||
if raw_count > sam_count:
|
||||
raise ValueError("raw freq must be higher than sampling freq")
|
||||
_calendar_minute = np.unique(
|
||||
list(
|
||||
map(lambda x: pd.Timestamp(x.year, x.month, x.day, *cal_next_sam_minute(x, sam_count), 0), calendar_raw)
|
||||
)
|
||||
)
|
||||
if calendar_raw[0] > _calendar_minute[0]:
|
||||
_calendar_minute[0] = calendar_raw[0]
|
||||
return _calendar_minute
|
||||
else:
|
||||
_calendar_day = np.unique(list(map(lambda x: pd.Timestamp(x.year, x.month, x.day, 0, 0, 0), calendar_raw)))
|
||||
if freq_sam == "day":
|
||||
return _calendar_day[::sam_count]
|
||||
|
||||
elif freq_sam == "week":
|
||||
_day_in_week = np.array(list(map(lambda x: x.dayofweek, _calendar_day)))
|
||||
_calendar_week = _calendar_day[np.ediff1d(_day_in_week, to_begin=-1) < 0]
|
||||
return _calendar_week[::sam_count]
|
||||
|
||||
elif freq_sam == "month":
|
||||
_day_in_month = np.array(list(map(lambda x: x.day, _calendar_day)))
|
||||
_calendar_month = _calendar_day[np.ediff1d(_day_in_month, to_begin=-1) < 0]
|
||||
return _calendar_month[::sam_count]
|
||||
else:
|
||||
raise ValueError("sampling freq must be xmin, xd, xw, xm")
|
||||
|
||||
|
||||
def get_sample_freq_calendar(
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
freq: str = "day",
|
||||
future: bool = False,
|
||||
) -> Tuple[np.ndarray, str, Optional[str]]:
|
||||
"""
|
||||
Get the calendar with frequency freq.
|
||||
|
||||
- If the calendar with the raw frequency freq exists, return it directly
|
||||
|
||||
- Else, sample from a higher frequency calendar automatically
|
||||
|
||||
Parameters
|
||||
----------
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start time of calendar, by default None
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end time of calendar, by default None
|
||||
freq : str, optional
|
||||
freq of calendar, by default "day"
|
||||
future : bool, optional
|
||||
whether including future trading day.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[np.ndarray, str, Optional[str]]
|
||||
|
||||
- the first value is the calendar
|
||||
- the second value is the raw freq of calendar
|
||||
- the third value is the sampling freq of calendar, it's None if the raw frequency freq exists.
|
||||
|
||||
"""
|
||||
|
||||
_, norm_freq = parse_freq(freq)
|
||||
|
||||
from ..data.data import Cal
|
||||
|
||||
try:
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq=freq, future=future)
|
||||
freq, freq_sam = freq, None
|
||||
except ValueError:
|
||||
freq_sam = freq
|
||||
if norm_freq in ["month", "week", "day"]:
|
||||
try:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="day", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "day"
|
||||
except ValueError:
|
||||
_calendar = Cal.calendar(
|
||||
start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future
|
||||
)
|
||||
freq = "min"
|
||||
elif norm_freq == "minute":
|
||||
_calendar = Cal.calendar(start_time=start_time, end_time=end_time, freq="min", freq_sam=freq, future=future)
|
||||
freq = "min"
|
||||
else:
|
||||
raise ValueError(f"freq {freq} is not supported")
|
||||
return _calendar, freq, freq_sam
|
||||
|
||||
|
||||
def sample_feature(
|
||||
feature: Union[pd.DataFrame, pd.Series],
|
||||
start_time: Union[str, pd.Timestamp] = None,
|
||||
end_time: Union[str, pd.Timestamp] = None,
|
||||
fields: Union[str, List[str]] = None,
|
||||
method: Union[str, Callable] = "last",
|
||||
method_kwargs: dict = {},
|
||||
):
|
||||
"""
|
||||
Sample value from pandas DataFrame or Series for each stock
|
||||
|
||||
- If `feature` has MultiIndex[instrument, datetime], apply the `method` to each instruemnt data with datetime in [start_time, end_time]
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
|
||||
print(feature)
|
||||
$close $volume
|
||||
instrument datetime
|
||||
SH600000 2010-01-04 86.778313 16162960.0
|
||||
2010-01-05 87.433578 28117442.0
|
||||
2010-01-06 85.713585 23632884.0
|
||||
2010-01-07 83.788803 20813402.0
|
||||
2010-01-08 84.730675 16044853.0
|
||||
|
||||
SH600655 2010-01-04 2699.567383 158193.328125
|
||||
2010-01-08 2612.359619 77501.406250
|
||||
2010-01-11 2712.982422 160852.390625
|
||||
2010-01-12 2788.688232 164587.937500
|
||||
2010-01-13 2790.604004 145460.453125
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
|
||||
$close $volume
|
||||
instrument
|
||||
SH600000 87.433578 28117442.0
|
||||
SH600655 2699.567383 158193.328125
|
||||
|
||||
- Else, the `feature` should have Index[datetime], just apply the `method` to `feature` directly
|
||||
Example:
|
||||
|
||||
.. code-block::
|
||||
print(feature)
|
||||
$close $volume
|
||||
datetime
|
||||
2010-01-04 86.778313 16162960.0
|
||||
2010-01-05 87.433578 28117442.0
|
||||
2010-01-06 85.713585 23632884.0
|
||||
2010-01-07 83.788803 20813402.0
|
||||
2010-01-08 84.730675 16044853.0
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields=["$close", "$volume"], method="last"))
|
||||
|
||||
$close 87.433578
|
||||
$volume 28117442.0
|
||||
|
||||
print(sample_feature(feature, start_time="2010-01-04", end_time="2010-01-05", fields="$close", method="last"))
|
||||
|
||||
87.433578
|
||||
|
||||
Parameters
|
||||
----------
|
||||
feature : Union[pd.DataFrame, pd.Series]
|
||||
Raw feature to be sampled
|
||||
start_time : Union[str, pd.Timestamp], optional
|
||||
start sampling time, by default None
|
||||
end_time : Union[str, pd.Timestamp], optional
|
||||
end sampling time, by default None
|
||||
fields : Union[str, List[str]], optional
|
||||
column names, it's ignored when sample pd.Series data, by default None(all columns)
|
||||
method : Union[str, Callable], optional
|
||||
sample method, apply method function to each stock series data, by default "last"
|
||||
- If type(method) is str, it should be an attribute of SeriesGroupBy or DataFrameGroupby, and run feature.groupby
|
||||
- If `feature` has MultiIndex[instrument, datetime], method must be a member of pandas.groupby when it's type is str.or callable function.
|
||||
method_kwargs : dict, optional
|
||||
arguments of method, by default {}
|
||||
|
||||
Returns
|
||||
-------
|
||||
The Sampled DataFrame/Series/Value
|
||||
"""
|
||||
|
||||
selector_datetime = slice(start_time, end_time)
|
||||
if fields is None:
|
||||
fields = slice(None)
|
||||
|
||||
from ..data.dataset.utils import get_level_index
|
||||
|
||||
datetime_level = get_level_index(feature, level="datetime") == 0
|
||||
if isinstance(feature, pd.Series):
|
||||
feature = feature.loc[selector_datetime] if datetime_level else feature.loc[(slice(None), selector_datetime)]
|
||||
elif isinstance(feature, pd.DataFrame):
|
||||
feature = (
|
||||
feature.loc[selector_datetime, fields]
|
||||
if datetime_level
|
||||
else feature.loc[(slice(None), selector_datetime), fields]
|
||||
)
|
||||
if feature.empty:
|
||||
return None
|
||||
if isinstance(feature.index, pd.MultiIndex):
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return feature.groupby(level="instrument").apply(lambda x: method_func(x, **method_kwargs))
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature.groupby(level="instrument"), method)(**method_kwargs)
|
||||
else:
|
||||
if callable(method):
|
||||
method_func = method
|
||||
return method_func(feature, **method_kwargs)
|
||||
elif isinstance(method, str):
|
||||
return getattr(feature, method)(**method_kwargs)
|
||||
|
||||
return feature
|
||||
@@ -14,7 +14,8 @@ from ..data.dataset import DatasetH
|
||||
from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict, parse_freq
|
||||
from ..utils import flatten_dict
|
||||
from ..utils.sample import parse_freq
|
||||
from ..strategy.base import BaseStrategy
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
@@ -315,16 +316,6 @@ class PortAnaRecord(RecordTemp):
|
||||
ret_freq.extend(self._get_report_freq(env_config["kwargs"]["sub_env"]))
|
||||
return ret_freq
|
||||
|
||||
def _cal_risk_analysis_scaler(self, freq):
|
||||
_count, _freq = parse_freq(freq)
|
||||
_freq_scaler = {
|
||||
"minute": 240 * 250,
|
||||
"day": 250,
|
||||
"week": 50,
|
||||
"month": 12,
|
||||
}
|
||||
return _count * _freq_scaler[_freq]
|
||||
|
||||
def generate(self, **kwargs):
|
||||
# custom strategy and get backtest
|
||||
report_dict = normal_backtest(env=self.env_config, strategy=self.strategy_config, **self.backtest_config)
|
||||
@@ -343,12 +334,11 @@ class PortAnaRecord(RecordTemp):
|
||||
else:
|
||||
report_normal, _ = report_dict.get(self.risk_analysis_freq)
|
||||
analysis = dict()
|
||||
risk_analysis_scaler = self._cal_risk_analysis_scaler(self.risk_analysis_freq)
|
||||
analysis["excess_return_without_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"], risk_analysis_scaler
|
||||
report_normal["return"] - report_normal["bench"], self.risk_analysis_freq
|
||||
)
|
||||
analysis["excess_return_with_cost"] = risk_analysis(
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], risk_analysis_scaler
|
||||
report_normal["return"] - report_normal["bench"] - report_normal["cost"], self.risk_analysis_freq
|
||||
)
|
||||
analysis_df = pd.concat(analysis) # type: pd.DataFrame
|
||||
# log metrics
|
||||
|
||||
Reference in New Issue
Block a user