mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 10:31:00 +08:00
Merge branch 'nested_decision_exe' of https://github.com/microsoft/qlib into rl-dummy
This commit is contained in:
@@ -196,27 +196,40 @@
|
||||
"# prediction, backtest & analysis\n",
|
||||
"###################################\n",
|
||||
"port_analysis_config = {\n",
|
||||
" \"executor\": {\n",
|
||||
" \"class\": \"SimulatorExecutor\",\n",
|
||||
" \"module_path\": \"qlib.backtest.executor\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"time_per_step\": \"day\",\n",
|
||||
" \"generate_report\": True,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"strategy\": {\n",
|
||||
" \"class\": \"TopkDropoutStrategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.strategy\",\n",
|
||||
" \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n",
|
||||
" \"kwargs\": {\n",
|
||||
" \"model\": model,\n",
|
||||
" \"dataset\": dataset,\n",
|
||||
" \"topk\": 50,\n",
|
||||
" \"n_drop\": 5,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
" \"backtest\": {\n",
|
||||
" \"verbose\": False,\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"start_time\": \"2017-01-01\",\n",
|
||||
" \"end_time\": \"2020-08-01\",\n",
|
||||
" \"account\": 100000000,\n",
|
||||
" \"benchmark\": benchmark,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" \"exchange_kwargs\": {\n",
|
||||
" \"freq\": \"day\",\n",
|
||||
" \"limit_threshold\": 0.095,\n",
|
||||
" \"deal_price\": \"close\",\n",
|
||||
" \"open_cost\": 0.0005,\n",
|
||||
" \"close_cost\": 0.0015,\n",
|
||||
" \"min_cost\": 5,\n",
|
||||
" },\n",
|
||||
" },\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"# backtest and analysis\n",
|
||||
"with R.start(experiment_name=\"backtest_analysis\"):\n",
|
||||
" recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n",
|
||||
@@ -229,7 +242,7 @@
|
||||
" sr.generate()\n",
|
||||
"\n",
|
||||
" # backtest & analysis\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config)\n",
|
||||
" par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n",
|
||||
" par.generate()\n"
|
||||
]
|
||||
},
|
||||
@@ -249,11 +262,12 @@
|
||||
"from qlib.contrib.report import analysis_model, analysis_position\n",
|
||||
"from qlib.data import D\n",
|
||||
"recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n",
|
||||
"print(recorder)\n",
|
||||
"pred_df = recorder.load_object(\"pred.pkl\")\n",
|
||||
"pred_df_dates = pred_df.index.get_level_values(level='datetime')\n",
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")"
|
||||
"report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n",
|
||||
"positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n",
|
||||
"analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -348,9 +362,8 @@
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
"name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b",
|
||||
"display_name": "Python 3.8 ('qlib_backtest': conda)"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
@@ -362,7 +375,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.3"
|
||||
"version": "3.8"
|
||||
},
|
||||
"toc": {
|
||||
"base_numbering": 1,
|
||||
@@ -376,6 +389,11 @@
|
||||
"toc_position": {},
|
||||
"toc_section_display": true,
|
||||
"toc_window_display": false
|
||||
},
|
||||
"metadata": {
|
||||
"interpreter": {
|
||||
"hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b"
|
||||
}
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
|
||||
import qlib
|
||||
from qlib.config import REG_CN
|
||||
from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict
|
||||
from qlib.utils import init_instance_by_config, flatten_dict
|
||||
from qlib.workflow import R
|
||||
from qlib.workflow.record_temp import SignalRecord, PortAnaRecord
|
||||
from qlib.tests.data import GetData
|
||||
from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -15,57 +17,8 @@ if __name__ == "__main__":
|
||||
GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True)
|
||||
qlib.init(provider_uri=provider_uri, region=REG_CN)
|
||||
|
||||
market = "csi300"
|
||||
benchmark = "SH000300"
|
||||
|
||||
###################################
|
||||
# train model
|
||||
###################################
|
||||
data_handler_config = {
|
||||
"start_time": "2008-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"fit_start_time": "2008-01-01",
|
||||
"fit_end_time": "2014-12-31",
|
||||
"instruments": market,
|
||||
}
|
||||
|
||||
task = {
|
||||
"model": {
|
||||
"class": "LGBModel",
|
||||
"module_path": "qlib.contrib.model.gbdt",
|
||||
"kwargs": {
|
||||
"loss": "mse",
|
||||
"colsample_bytree": 0.8879,
|
||||
"learning_rate": 0.0421,
|
||||
"subsample": 0.8789,
|
||||
"lambda_l1": 205.6999,
|
||||
"lambda_l2": 580.9768,
|
||||
"max_depth": 8,
|
||||
"num_leaves": 210,
|
||||
"num_threads": 20,
|
||||
},
|
||||
},
|
||||
"dataset": {
|
||||
"class": "DatasetH",
|
||||
"module_path": "qlib.data.dataset",
|
||||
"kwargs": {
|
||||
"handler": {
|
||||
"class": "Alpha158",
|
||||
"module_path": "qlib.contrib.data.handler",
|
||||
"kwargs": data_handler_config,
|
||||
},
|
||||
"segments": {
|
||||
"train": ("2008-01-01", "2014-12-31"),
|
||||
"valid": ("2015-01-01", "2016-12-31"),
|
||||
"test": ("2017-01-01", "2020-08-01"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# model initialization
|
||||
model = init_instance_by_config(task["model"])
|
||||
dataset = init_instance_by_config(task["dataset"])
|
||||
model = init_instance_by_config(CSI300_GBDT_TASK["model"])
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
|
||||
port_analysis_config = {
|
||||
"executor": {
|
||||
@@ -90,7 +43,7 @@ if __name__ == "__main__":
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": benchmark,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"limit_threshold": 0.095,
|
||||
|
||||
@@ -26,6 +26,7 @@ class Exchange:
|
||||
deal_price=None,
|
||||
subscribe_fields=[],
|
||||
limit_threshold=None,
|
||||
volume_threshold=None,
|
||||
open_cost=0.0015,
|
||||
close_cost=0.0025,
|
||||
trade_unit=None,
|
||||
@@ -41,6 +42,7 @@ class Exchange:
|
||||
:param deal_price: str, 'close', 'open', 'vwap'
|
||||
:param subscribe_fields: list, subscribe fields
|
||||
:param limit_threshold: float, 0.1 for example, default None
|
||||
:param volume_threshold: float, 0.1 for example, default None
|
||||
:param open_cost: cost rate for open, default 0.0015
|
||||
:param close_cost: cost rate for close, default 0.0025
|
||||
:param trade_unit: trade unit, 100 for China A market
|
||||
@@ -60,6 +62,7 @@ class Exchange:
|
||||
self.freq = freq
|
||||
self.start_time = start_time
|
||||
self.end_time = end_time
|
||||
|
||||
if trade_unit is None:
|
||||
trade_unit = C.trade_unit
|
||||
if limit_threshold is None:
|
||||
@@ -70,7 +73,6 @@ class Exchange:
|
||||
self.logger = get_module_logger("online operator", level=logging.INFO)
|
||||
|
||||
self.trade_unit = trade_unit
|
||||
|
||||
# TODO: the quote, trade_dates, codes are not necessray.
|
||||
# It is just for performance consideration.
|
||||
if limit_threshold is None:
|
||||
@@ -100,7 +102,7 @@ class Exchange:
|
||||
self.close_cost = close_cost
|
||||
self.min_cost = min_cost
|
||||
self.limit_threshold = limit_threshold
|
||||
|
||||
self.volume_threshold = volume_threshold
|
||||
self.extra_quote = extra_quote
|
||||
self.set_quote(codes, start_time, end_time)
|
||||
|
||||
@@ -120,14 +122,19 @@ class Exchange:
|
||||
# Use adjusted price
|
||||
self.trade_w_adj_price = True
|
||||
self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.")
|
||||
if self.trade_unit is not None:
|
||||
self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.")
|
||||
|
||||
else:
|
||||
# The `factor.day.bin` file exists and all data `close` and `factor` are not `nan`
|
||||
# Use normal price
|
||||
self.trade_w_adj_price = False
|
||||
|
||||
# update limit
|
||||
# check limit_threshold
|
||||
if self.limit_threshold is None:
|
||||
self.quote["limit"] = False
|
||||
self.quote["limit_buy"] = False
|
||||
self.quote["limit_sell"] = False
|
||||
else:
|
||||
# set limit
|
||||
self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold)
|
||||
@@ -143,9 +150,13 @@ class Exchange:
|
||||
if "$factor" not in self.extra_quote.columns:
|
||||
self.extra_quote["$factor"] = 1.0
|
||||
self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.")
|
||||
if "limit" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit"] = False
|
||||
self.logger.warning("No limit set for extra_quote. All stock will be tradable.")
|
||||
if "limit_sell" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit_sell"] = False
|
||||
self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.")
|
||||
if "limit_buy" not in self.extra_quote.columns:
|
||||
self.extra_quote["limit_buy"] = False
|
||||
self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.")
|
||||
|
||||
assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"}
|
||||
quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0)
|
||||
|
||||
@@ -160,15 +171,30 @@ class Exchange:
|
||||
self.quote = quote_dict
|
||||
|
||||
def _update_limit(self, buy_limit, sell_limit):
|
||||
self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False)
|
||||
self.quote["limit_buy"] = ~self.quote["$change"].lt(buy_limit)
|
||||
self.quote["limit_sell"] = ~self.quote["$change"].gt(-sell_limit)
|
||||
|
||||
def check_stock_limit(self, stock_id, start_time, end_time):
|
||||
"""Parameter
|
||||
stock_id
|
||||
trade_date
|
||||
is limtited
|
||||
def check_stock_limit(self, stock_id, start_time, end_time, direction=None):
|
||||
"""
|
||||
return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0]
|
||||
Parameters
|
||||
----------
|
||||
direction : int, optional
|
||||
trade direction, by default None
|
||||
- if direction is None, check if tradable for buying and selling.
|
||||
- if direction == Order.BUY, check the if tradable for buying
|
||||
- if direction == Order.SELL, check the sell limit for selling.
|
||||
|
||||
"""
|
||||
if direction is None:
|
||||
buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
return buy_limit or sell_limit
|
||||
elif direction == Order.BUY:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0]
|
||||
elif direction == Order.SELL:
|
||||
return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0]
|
||||
else:
|
||||
raise ValueError(f"direction {direction} is not supported!")
|
||||
|
||||
def check_stock_suspended(self, stock_id, start_time, end_time):
|
||||
# is suspended
|
||||
@@ -177,11 +203,11 @@ class Exchange:
|
||||
else:
|
||||
return True
|
||||
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time):
|
||||
def is_stock_tradable(self, stock_id, start_time, end_time, direction=None):
|
||||
# check if stock can be traded
|
||||
# same as check in check_order
|
||||
if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit(
|
||||
stock_id, start_time, end_time
|
||||
stock_id, start_time, end_time, direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
@@ -190,7 +216,7 @@ class Exchange:
|
||||
def check_order(self, order):
|
||||
# check limit and suspended
|
||||
if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit(
|
||||
order.stock_id, order.start_time, order.end_time
|
||||
order.stock_id, order.start_time, order.end_time, order.direction
|
||||
):
|
||||
return False
|
||||
else:
|
||||
@@ -220,8 +246,8 @@ class Exchange:
|
||||
order, trade_account.current if trade_account else position
|
||||
)
|
||||
# update account
|
||||
if trade_val > 0:
|
||||
# If the order can only be deal 0 trade_val. Nothing to be updated
|
||||
if order.deal_amount > 1e-5:
|
||||
# If the order can only be deal 0 aomount. Nothing to be updated
|
||||
# Otherwise, it will result some stock with 0 amount in the position
|
||||
if trade_account:
|
||||
trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price)
|
||||
@@ -393,7 +419,7 @@ class Exchange:
|
||||
return value
|
||||
|
||||
def get_amount_of_trade_unit(self, factor):
|
||||
if not self.trade_w_adj_price:
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
return self.trade_unit / factor
|
||||
else:
|
||||
return None
|
||||
@@ -404,11 +430,18 @@ class Exchange:
|
||||
factor : float, adjusted factor
|
||||
return : float, real amount
|
||||
"""
|
||||
if not self.trade_w_adj_price:
|
||||
if not self.trade_w_adj_price and self.trade_unit is not None:
|
||||
# the minimal amount is 1. Add 0.1 for solving precision problem.
|
||||
return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor
|
||||
return deal_amount
|
||||
|
||||
def _get_amount_by_volume(self, stock_id, trade_start_time, trade_end_time, deal_amount):
|
||||
if self.volume_threshold is not None:
|
||||
tradable_amount = self.get_volume(stock_id, trade_start_time, trade_end_time) * self.volume_threshold
|
||||
return max(min(tradable_amount, deal_amount), 0)
|
||||
else:
|
||||
return deal_amount
|
||||
|
||||
def _calc_trade_info_by_order(self, order, position):
|
||||
"""
|
||||
Calculation of trade info
|
||||
@@ -422,9 +455,14 @@ class Exchange:
|
||||
if order.direction == Order.SELL:
|
||||
# sell
|
||||
if position is not None:
|
||||
if np.isclose(order.amount, position.get_stock_amount(order.stock_id)):
|
||||
current_amount = (
|
||||
position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0
|
||||
)
|
||||
if np.isclose(order.amount, current_amount):
|
||||
# when selling last stock. The amount don't need rounding
|
||||
order.deal_amount = order.amount
|
||||
elif order.amount > current_amount:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor)
|
||||
else:
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
else:
|
||||
@@ -432,6 +470,9 @@ class Exchange:
|
||||
# We choose to sell all
|
||||
order.deal_amount = order.amount
|
||||
|
||||
order.deal_amount = self._get_amount_by_volume(
|
||||
order.stock_id, order.start_time, order.end_time, order.deal_amount
|
||||
)
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = max(trade_val * self.close_cost, self.min_cost)
|
||||
elif order.direction == Order.BUY:
|
||||
@@ -451,6 +492,9 @@ class Exchange:
|
||||
# Unknown amount of money. Just round the amount
|
||||
order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor)
|
||||
|
||||
order.deal_amount = self._get_amount_by_volume(
|
||||
order.stock_id, order.start_time, order.end_time, order.deal_amount
|
||||
)
|
||||
trade_val = order.deal_amount * trade_price
|
||||
trade_cost = trade_val * self.open_cost
|
||||
else:
|
||||
|
||||
@@ -118,7 +118,8 @@ class BaseExecutor:
|
||||
def get_report(self):
|
||||
raise NotImplementedError("get_report is not implemented!")
|
||||
|
||||
def get_all_executor(self):
|
||||
def get_all_executors(self):
|
||||
"""Return all executors"""
|
||||
return [self]
|
||||
|
||||
|
||||
@@ -247,8 +248,9 @@ class NestedExecutor(BaseExecutor):
|
||||
sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)})
|
||||
return sub_env_report_dict
|
||||
|
||||
def get_all_executor(self):
|
||||
return [self, *self.inner_executor.get_all_executor()]
|
||||
def get_all_executors(self):
|
||||
"""Return all executors, including self and inner_executor.get_all_executors()"""
|
||||
return [self, *self.inner_executor.get_all_executors()]
|
||||
|
||||
|
||||
class SimulatorExecutor(BaseExecutor):
|
||||
|
||||
@@ -73,6 +73,9 @@ class Position:
|
||||
def del_stock(self, stock_id):
|
||||
del self.position[stock_id]
|
||||
|
||||
def check_stock(self, stock_id):
|
||||
return stock_id in self.position
|
||||
|
||||
def update_order(self, order, trade_val, cost, trade_price):
|
||||
# handle order, order is a order class, defined in exchange.py
|
||||
if order.direction == Order.BUY:
|
||||
|
||||
@@ -12,6 +12,7 @@ from pandas.core.frame import DataFrame
|
||||
|
||||
from ..utils.resam import parse_freq, resam_ts_data
|
||||
from ..data import D
|
||||
from ..tests.config import CSI300_BENCH
|
||||
|
||||
|
||||
class Report:
|
||||
@@ -67,7 +68,7 @@ class Report:
|
||||
self.bench = self._cal_benchmark(self.benchmark_config, self.freq)
|
||||
|
||||
def _cal_benchmark(self, benchmark_config, freq):
|
||||
benchmark = benchmark_config.get("benchmark", "SH000300")
|
||||
benchmark = benchmark_config.get("benchmark", CSI300_BENCH)
|
||||
if isinstance(benchmark, pd.Series):
|
||||
return benchmark
|
||||
else:
|
||||
|
||||
@@ -29,7 +29,7 @@ def risk_analysis(r, N: int = None, freq: str = "day"):
|
||||
r : pandas.Series
|
||||
daily return series.
|
||||
N: int
|
||||
scaler for annualizing information_ratio (day: 250, week: 50, month: 12), at least one of `N` and `freq` should exist
|
||||
scaler for annualizing information_ratio (day: 252, week: 50, month: 12), at least one of `N` and `freq` should exist
|
||||
freq: str
|
||||
analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist
|
||||
"""
|
||||
|
||||
@@ -14,27 +14,6 @@ from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord
|
||||
from qlib.tests import TestAutoData
|
||||
from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH
|
||||
|
||||
port_analysis_config = {
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.strategy",
|
||||
"kwargs": {
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"verbose": False,
|
||||
"limit_threshold": 0.095,
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def train():
|
||||
"""train model
|
||||
@@ -58,7 +37,7 @@ def train():
|
||||
with R.start(experiment_name="workflow"):
|
||||
R.log_params(**flatten_dict(CSI300_GBDT_TASK))
|
||||
model.fit(dataset)
|
||||
|
||||
R.save_objects(trained_model=model)
|
||||
# prediction
|
||||
recorder = R.get_recorder()
|
||||
# To test __repr__
|
||||
@@ -68,7 +47,6 @@ def train():
|
||||
rid = recorder.id
|
||||
sr = SignalRecord(model, dataset, recorder)
|
||||
sr.generate()
|
||||
pred_score = sr.load()
|
||||
|
||||
# calculate ic and ric
|
||||
sar = SigAnaRecord(recorder)
|
||||
@@ -76,7 +54,7 @@ def train():
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
|
||||
return pred_score, {"ic": ic, "ric": ric}, rid
|
||||
return {"ic": ic, "ric": ric}, rid
|
||||
|
||||
|
||||
def train_with_sigana():
|
||||
@@ -103,10 +81,9 @@ def train_with_sigana():
|
||||
sar.generate()
|
||||
ic = sar.load(sar.get_path("ic.pkl"))
|
||||
ric = sar.load(sar.get_path("ric.pkl"))
|
||||
pred_score = sar.load("pred.pkl")
|
||||
|
||||
uri_path = R.get_uri()
|
||||
return pred_score, {"ic": ic, "ric": ric}, uri_path
|
||||
return {"ic": ic, "ric": ric}, uri_path
|
||||
|
||||
|
||||
def fake_experiment():
|
||||
@@ -130,13 +107,11 @@ def fake_experiment():
|
||||
return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri
|
||||
|
||||
|
||||
def backtest_analysis(pred, rid):
|
||||
def backtest_analysis(rid):
|
||||
"""backtest and analysis
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict scores
|
||||
rid : str
|
||||
the id of the recorder to be used in this function
|
||||
|
||||
@@ -147,16 +122,54 @@ def backtest_analysis(pred, rid):
|
||||
|
||||
"""
|
||||
recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid)
|
||||
|
||||
dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"])
|
||||
model = recorder.load_object("trained_model")
|
||||
|
||||
port_analysis_config = {
|
||||
"executor": {
|
||||
"class": "SimulatorExecutor",
|
||||
"module_path": "qlib.backtest.executor",
|
||||
"kwargs": {
|
||||
"time_per_step": "day",
|
||||
"generate_report": True,
|
||||
},
|
||||
},
|
||||
"strategy": {
|
||||
"class": "TopkDropoutStrategy",
|
||||
"module_path": "qlib.contrib.strategy.model_strategy",
|
||||
"kwargs": {
|
||||
"model": model,
|
||||
"dataset": dataset,
|
||||
"topk": 50,
|
||||
"n_drop": 5,
|
||||
},
|
||||
},
|
||||
"backtest": {
|
||||
"start_time": "2017-01-01",
|
||||
"end_time": "2020-08-01",
|
||||
"account": 100000000,
|
||||
"benchmark": CSI300_BENCH,
|
||||
"exchange_kwargs": {
|
||||
"freq": "day",
|
||||
"limit_threshold": 0.095,
|
||||
"deal_price": "close",
|
||||
"open_cost": 0.0005,
|
||||
"close_cost": 0.0015,
|
||||
"min_cost": 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
# backtest
|
||||
par = PortAnaRecord(recorder, port_analysis_config)
|
||||
par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
|
||||
par.generate()
|
||||
analysis_df = par.load(par.get_path("port_analysis.pkl"))
|
||||
analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
|
||||
print(analysis_df)
|
||||
return analysis_df
|
||||
|
||||
|
||||
class TestAllFlow(TestAutoData):
|
||||
PRED_SCORE = None
|
||||
REPORT_NORMAL = None
|
||||
POSITIONS = None
|
||||
RID = None
|
||||
@@ -166,18 +179,18 @@ class TestAllFlow(TestAutoData):
|
||||
shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve()))
|
||||
|
||||
def test_0_train_with_sigana(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana()
|
||||
ic_ric, uri_path = train_with_sigana()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
shutil.rmtree(str(Path(uri_path.strip("file:")).resolve()))
|
||||
|
||||
def test_1_train(self):
|
||||
TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train()
|
||||
ic_ric, TestAllFlow.RID = train()
|
||||
self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed")
|
||||
self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed")
|
||||
|
||||
def test_2_backtest(self):
|
||||
analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID)
|
||||
analyze_df = backtest_analysis(TestAllFlow.RID)
|
||||
self.assertGreaterEqual(
|
||||
analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0],
|
||||
0.10,
|
||||
|
||||
Reference in New Issue
Block a user