diff --git a/examples/workflow_by_code.ipynb b/examples/workflow_by_code.ipynb index b4da1bfe4..3d99bf1e1 100644 --- a/examples/workflow_by_code.ipynb +++ b/examples/workflow_by_code.ipynb @@ -196,27 +196,40 @@ "# prediction, backtest & analysis\n", "###################################\n", "port_analysis_config = {\n", + " \"executor\": {\n", + " \"class\": \"SimulatorExecutor\",\n", + " \"module_path\": \"qlib.backtest.executor\",\n", + " \"kwargs\": {\n", + " \"time_per_step\": \"day\",\n", + " \"generate_report\": True,\n", + " },\n", + " },\n", " \"strategy\": {\n", " \"class\": \"TopkDropoutStrategy\",\n", - " \"module_path\": \"qlib.contrib.strategy.strategy\",\n", + " \"module_path\": \"qlib.contrib.strategy.model_strategy\",\n", " \"kwargs\": {\n", + " \"model\": model,\n", + " \"dataset\": dataset,\n", " \"topk\": 50,\n", " \"n_drop\": 5,\n", " },\n", " },\n", " \"backtest\": {\n", - " \"verbose\": False,\n", - " \"limit_threshold\": 0.095,\n", + " \"start_time\": \"2017-01-01\",\n", + " \"end_time\": \"2020-08-01\",\n", " \"account\": 100000000,\n", " \"benchmark\": benchmark,\n", - " \"deal_price\": \"close\",\n", - " \"open_cost\": 0.0005,\n", - " \"close_cost\": 0.0015,\n", - " \"min_cost\": 5,\n", + " \"exchange_kwargs\": {\n", + " \"freq\": \"day\",\n", + " \"limit_threshold\": 0.095,\n", + " \"deal_price\": \"close\",\n", + " \"open_cost\": 0.0005,\n", + " \"close_cost\": 0.0015,\n", + " \"min_cost\": 5,\n", + " },\n", " },\n", "}\n", "\n", - "\n", "# backtest and analysis\n", "with R.start(experiment_name=\"backtest_analysis\"):\n", " recorder = R.get_recorder(rid, experiment_name=\"train_model\")\n", @@ -229,7 +242,7 @@ " sr.generate()\n", "\n", " # backtest & analysis\n", - " par = PortAnaRecord(recorder, port_analysis_config)\n", + " par = PortAnaRecord(recorder, port_analysis_config, \"day\")\n", " par.generate()\n" ] }, @@ -249,11 +262,12 @@ "from qlib.contrib.report import analysis_model, analysis_position\n", "from qlib.data import D\n", "recorder = R.get_recorder(ba_rid, experiment_name=\"backtest_analysis\")\n", + "print(recorder)\n", "pred_df = recorder.load_object(\"pred.pkl\")\n", "pred_df_dates = pred_df.index.get_level_values(level='datetime')\n", - "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal.pkl\")\n", - "positions = recorder.load_object(\"portfolio_analysis/positions_normal.pkl\")\n", - "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis.pkl\")" + "report_normal_df = recorder.load_object(\"portfolio_analysis/report_normal_1day.pkl\")\n", + "positions = recorder.load_object(\"portfolio_analysis/positions_normal_1day.pkl\")\n", + "analysis_df = recorder.load_object(\"portfolio_analysis/port_analysis_1day.pkl\")" ] }, { @@ -348,9 +362,8 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" + "name": "pythonjvsc74a57bd0fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b", + "display_name": "Python 3.8 ('qlib_backtest': conda)" }, "language_info": { "codemirror_mode": { @@ -362,7 +375,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.3" + "version": "3.8" }, "toc": { "base_numbering": 1, @@ -376,6 +389,11 @@ "toc_position": {}, "toc_section_display": true, "toc_window_display": false + }, + "metadata": { + "interpreter": { + "hash": "fcc004278713aaede7c629a6a43738a929cb09abb52817d4f72eb70db44cd87b" + } } }, "nbformat": 4, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index b02ea91b1..d7bb544f9 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -3,10 +3,12 @@ import qlib from qlib.config import REG_CN -from qlib.utils import exists_qlib_data, init_instance_by_config, flatten_dict +from qlib.utils import init_instance_by_config, flatten_dict from qlib.workflow import R from qlib.workflow.record_temp import SignalRecord, PortAnaRecord from qlib.tests.data import GetData +from qlib.tests.config import CSI300_BENCH, CSI300_GBDT_TASK + if __name__ == "__main__": @@ -15,57 +17,8 @@ if __name__ == "__main__": GetData().qlib_data(target_dir=provider_uri, region=REG_CN, exists_skip=True) qlib.init(provider_uri=provider_uri, region=REG_CN) - market = "csi300" - benchmark = "SH000300" - - ################################### - # train model - ################################### - data_handler_config = { - "start_time": "2008-01-01", - "end_time": "2020-08-01", - "fit_start_time": "2008-01-01", - "fit_end_time": "2014-12-31", - "instruments": market, - } - - task = { - "model": { - "class": "LGBModel", - "module_path": "qlib.contrib.model.gbdt", - "kwargs": { - "loss": "mse", - "colsample_bytree": 0.8879, - "learning_rate": 0.0421, - "subsample": 0.8789, - "lambda_l1": 205.6999, - "lambda_l2": 580.9768, - "max_depth": 8, - "num_leaves": 210, - "num_threads": 20, - }, - }, - "dataset": { - "class": "DatasetH", - "module_path": "qlib.data.dataset", - "kwargs": { - "handler": { - "class": "Alpha158", - "module_path": "qlib.contrib.data.handler", - "kwargs": data_handler_config, - }, - "segments": { - "train": ("2008-01-01", "2014-12-31"), - "valid": ("2015-01-01", "2016-12-31"), - "test": ("2017-01-01", "2020-08-01"), - }, - }, - }, - } - - # model initialization - model = init_instance_by_config(task["model"]) - dataset = init_instance_by_config(task["dataset"]) + model = init_instance_by_config(CSI300_GBDT_TASK["model"]) + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) port_analysis_config = { "executor": { @@ -90,7 +43,7 @@ if __name__ == "__main__": "start_time": "2017-01-01", "end_time": "2020-08-01", "account": 100000000, - "benchmark": benchmark, + "benchmark": CSI300_BENCH, "exchange_kwargs": { "freq": "day", "limit_threshold": 0.095, diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index de2df98be..4fc01d8e2 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -26,6 +26,7 @@ class Exchange: deal_price=None, subscribe_fields=[], limit_threshold=None, + volume_threshold=None, open_cost=0.0015, close_cost=0.0025, trade_unit=None, @@ -41,6 +42,7 @@ class Exchange: :param deal_price: str, 'close', 'open', 'vwap' :param subscribe_fields: list, subscribe fields :param limit_threshold: float, 0.1 for example, default None + :param volume_threshold: float, 0.1 for example, default None :param open_cost: cost rate for open, default 0.0015 :param close_cost: cost rate for close, default 0.0025 :param trade_unit: trade unit, 100 for China A market @@ -60,6 +62,7 @@ class Exchange: self.freq = freq self.start_time = start_time self.end_time = end_time + if trade_unit is None: trade_unit = C.trade_unit if limit_threshold is None: @@ -70,7 +73,6 @@ class Exchange: self.logger = get_module_logger("online operator", level=logging.INFO) self.trade_unit = trade_unit - # TODO: the quote, trade_dates, codes are not necessray. # It is just for performance consideration. if limit_threshold is None: @@ -100,7 +102,7 @@ class Exchange: self.close_cost = close_cost self.min_cost = min_cost self.limit_threshold = limit_threshold - + self.volume_threshold = volume_threshold self.extra_quote = extra_quote self.set_quote(codes, start_time, end_time) @@ -120,14 +122,19 @@ class Exchange: # Use adjusted price self.trade_w_adj_price = True self.logger.warning("factor.day.bin file not exists or factor contains `nan`. Order using adjusted_price.") + if self.trade_unit is not None: + self.logger.warning(f"trade unit {self.trade_unit} is not supported in adjusted_price mode.") + else: # The `factor.day.bin` file exists and all data `close` and `factor` are not `nan` # Use normal price self.trade_w_adj_price = False + # update limit # check limit_threshold if self.limit_threshold is None: - self.quote["limit"] = False + self.quote["limit_buy"] = False + self.quote["limit_sell"] = False else: # set limit self._update_limit(buy_limit=self.limit_threshold, sell_limit=self.limit_threshold) @@ -143,9 +150,13 @@ class Exchange: if "$factor" not in self.extra_quote.columns: self.extra_quote["$factor"] = 1.0 self.logger.warning("No $factor set for extra_quote. Use 1.0 as $factor.") - if "limit" not in self.extra_quote.columns: - self.extra_quote["limit"] = False - self.logger.warning("No limit set for extra_quote. All stock will be tradable.") + if "limit_sell" not in self.extra_quote.columns: + self.extra_quote["limit_sell"] = False + self.logger.warning("No limit_sell set for extra_quote. All stock will be able to be sold.") + if "limit_buy" not in self.extra_quote.columns: + self.extra_quote["limit_buy"] = False + self.logger.warning("No limit_buy set for extra_quote. All stock will be able to be bought.") + assert set(self.extra_quote.columns) == set(quote_df.columns) - {"$change"} quote_df = pd.concat([quote_df, self.extra_quote], sort=False, axis=0) @@ -160,15 +171,30 @@ class Exchange: self.quote = quote_dict def _update_limit(self, buy_limit, sell_limit): - self.quote["limit"] = ~self.quote["$change"].between(-sell_limit, buy_limit, inclusive=False) + self.quote["limit_buy"] = ~self.quote["$change"].lt(buy_limit) + self.quote["limit_sell"] = ~self.quote["$change"].gt(-sell_limit) - def check_stock_limit(self, stock_id, start_time, end_time): - """Parameter - stock_id - trade_date - is limtited + def check_stock_limit(self, stock_id, start_time, end_time, direction=None): """ - return resam_ts_data(self.quote[stock_id]["limit"], start_time, end_time, method="all").iloc[0] + Parameters + ---------- + direction : int, optional + trade direction, by default None + - if direction is None, check if tradable for buying and selling. + - if direction == Order.BUY, check the if tradable for buying + - if direction == Order.SELL, check the sell limit for selling. + + """ + if direction is None: + buy_limit = resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + sell_limit = resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + return buy_limit or sell_limit + elif direction == Order.BUY: + return resam_ts_data(self.quote[stock_id]["limit_buy"], start_time, end_time, method="all").iloc[0] + elif direction == Order.SELL: + return resam_ts_data(self.quote[stock_id]["limit_sell"], start_time, end_time, method="all").iloc[0] + else: + raise ValueError(f"direction {direction} is not supported!") def check_stock_suspended(self, stock_id, start_time, end_time): # is suspended @@ -177,11 +203,11 @@ class Exchange: else: return True - def is_stock_tradable(self, stock_id, start_time, end_time): + def is_stock_tradable(self, stock_id, start_time, end_time, direction=None): # check if stock can be traded # same as check in check_order if self.check_stock_suspended(stock_id, start_time, end_time) or self.check_stock_limit( - stock_id, start_time, end_time + stock_id, start_time, end_time, direction ): return False else: @@ -190,7 +216,7 @@ class Exchange: def check_order(self, order): # check limit and suspended if self.check_stock_suspended(order.stock_id, order.start_time, order.end_time) or self.check_stock_limit( - order.stock_id, order.start_time, order.end_time + order.stock_id, order.start_time, order.end_time, order.direction ): return False else: @@ -220,8 +246,8 @@ class Exchange: order, trade_account.current if trade_account else position ) # update account - if trade_val > 0: - # If the order can only be deal 0 trade_val. Nothing to be updated + if order.deal_amount > 1e-5: + # If the order can only be deal 0 aomount. Nothing to be updated # Otherwise, it will result some stock with 0 amount in the position if trade_account: trade_account.update_order(order=order, trade_val=trade_val, cost=trade_cost, trade_price=trade_price) @@ -393,7 +419,7 @@ class Exchange: return value def get_amount_of_trade_unit(self, factor): - if not self.trade_w_adj_price: + if not self.trade_w_adj_price and self.trade_unit is not None: return self.trade_unit / factor else: return None @@ -404,11 +430,18 @@ class Exchange: factor : float, adjusted factor return : float, real amount """ - if not self.trade_w_adj_price: + if not self.trade_w_adj_price and self.trade_unit is not None: # the minimal amount is 1. Add 0.1 for solving precision problem. return (deal_amount * factor + 0.1) // self.trade_unit * self.trade_unit / factor return deal_amount + def _get_amount_by_volume(self, stock_id, trade_start_time, trade_end_time, deal_amount): + if self.volume_threshold is not None: + tradable_amount = self.get_volume(stock_id, trade_start_time, trade_end_time) * self.volume_threshold + return max(min(tradable_amount, deal_amount), 0) + else: + return deal_amount + def _calc_trade_info_by_order(self, order, position): """ Calculation of trade info @@ -422,9 +455,14 @@ class Exchange: if order.direction == Order.SELL: # sell if position is not None: - if np.isclose(order.amount, position.get_stock_amount(order.stock_id)): + current_amount = ( + position.get_stock_amount(order.stock_id) if position.check_stock(order.stock_id) else 0 + ) + if np.isclose(order.amount, current_amount): # when selling last stock. The amount don't need rounding order.deal_amount = order.amount + elif order.amount > current_amount: + order.deal_amount = self.round_amount_by_trade_unit(current_amount, order.factor) else: order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) else: @@ -432,6 +470,9 @@ class Exchange: # We choose to sell all order.deal_amount = order.amount + order.deal_amount = self._get_amount_by_volume( + order.stock_id, order.start_time, order.end_time, order.deal_amount + ) trade_val = order.deal_amount * trade_price trade_cost = max(trade_val * self.close_cost, self.min_cost) elif order.direction == Order.BUY: @@ -451,6 +492,9 @@ class Exchange: # Unknown amount of money. Just round the amount order.deal_amount = self.round_amount_by_trade_unit(order.amount, order.factor) + order.deal_amount = self._get_amount_by_volume( + order.stock_id, order.start_time, order.end_time, order.deal_amount + ) trade_val = order.deal_amount * trade_price trade_cost = trade_val * self.open_cost else: diff --git a/qlib/backtest/executor.py b/qlib/backtest/executor.py index e68047e38..656073759 100644 --- a/qlib/backtest/executor.py +++ b/qlib/backtest/executor.py @@ -118,7 +118,8 @@ class BaseExecutor: def get_report(self): raise NotImplementedError("get_report is not implemented!") - def get_all_executor(self): + def get_all_executors(self): + """Return all executors""" return [self] @@ -247,8 +248,9 @@ class NestedExecutor(BaseExecutor): sub_env_report_dict.update({f"{_count}{_freq}": (_report, _positions)}) return sub_env_report_dict - def get_all_executor(self): - return [self, *self.inner_executor.get_all_executor()] + def get_all_executors(self): + """Return all executors, including self and inner_executor.get_all_executors()""" + return [self, *self.inner_executor.get_all_executors()] class SimulatorExecutor(BaseExecutor): diff --git a/qlib/backtest/position.py b/qlib/backtest/position.py index c6368606a..92b549063 100644 --- a/qlib/backtest/position.py +++ b/qlib/backtest/position.py @@ -73,6 +73,9 @@ class Position: def del_stock(self, stock_id): del self.position[stock_id] + def check_stock(self, stock_id): + return stock_id in self.position + def update_order(self, order, trade_val, cost, trade_price): # handle order, order is a order class, defined in exchange.py if order.direction == Order.BUY: diff --git a/qlib/backtest/report.py b/qlib/backtest/report.py index 4b9b0ce26..0668f81cf 100644 --- a/qlib/backtest/report.py +++ b/qlib/backtest/report.py @@ -12,6 +12,7 @@ from pandas.core.frame import DataFrame from ..utils.resam import parse_freq, resam_ts_data from ..data import D +from ..tests.config import CSI300_BENCH class Report: @@ -67,7 +68,7 @@ class Report: self.bench = self._cal_benchmark(self.benchmark_config, self.freq) def _cal_benchmark(self, benchmark_config, freq): - benchmark = benchmark_config.get("benchmark", "SH000300") + benchmark = benchmark_config.get("benchmark", CSI300_BENCH) if isinstance(benchmark, pd.Series): return benchmark else: diff --git a/qlib/contrib/evaluate.py b/qlib/contrib/evaluate.py index 8d4052cdb..0ef8f95a5 100644 --- a/qlib/contrib/evaluate.py +++ b/qlib/contrib/evaluate.py @@ -29,7 +29,7 @@ def risk_analysis(r, N: int = None, freq: str = "day"): r : pandas.Series daily return series. N: int - scaler for annualizing information_ratio (day: 250, week: 50, month: 12), at least one of `N` and `freq` should exist + scaler for annualizing information_ratio (day: 252, week: 50, month: 12), at least one of `N` and `freq` should exist freq: str analysis frequency used for calculating the scaler, at least one of `N` and `freq` should exist """ diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 4c20405fa..ea171f31e 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -14,27 +14,6 @@ from qlib.workflow.record_temp import SignalRecord, SigAnaRecord, PortAnaRecord from qlib.tests import TestAutoData from qlib.tests.config import CSI300_GBDT_TASK, CSI300_BENCH -port_analysis_config = { - "strategy": { - "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.strategy", - "kwargs": { - "topk": 50, - "n_drop": 5, - }, - }, - "backtest": { - "verbose": False, - "limit_threshold": 0.095, - "account": 100000000, - "benchmark": CSI300_BENCH, - "deal_price": "close", - "open_cost": 0.0005, - "close_cost": 0.0015, - "min_cost": 5, - }, -} - def train(): """train model @@ -58,7 +37,7 @@ def train(): with R.start(experiment_name="workflow"): R.log_params(**flatten_dict(CSI300_GBDT_TASK)) model.fit(dataset) - + R.save_objects(trained_model=model) # prediction recorder = R.get_recorder() # To test __repr__ @@ -68,7 +47,6 @@ def train(): rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load() # calculate ic and ric sar = SigAnaRecord(recorder) @@ -76,7 +54,7 @@ def train(): ic = sar.load(sar.get_path("ic.pkl")) ric = sar.load(sar.get_path("ric.pkl")) - return pred_score, {"ic": ic, "ric": ric}, rid + return {"ic": ic, "ric": ric}, rid def train_with_sigana(): @@ -103,10 +81,9 @@ def train_with_sigana(): sar.generate() ic = sar.load(sar.get_path("ic.pkl")) ric = sar.load(sar.get_path("ric.pkl")) - pred_score = sar.load("pred.pkl") uri_path = R.get_uri() - return pred_score, {"ic": ic, "ric": ric}, uri_path + return {"ic": ic, "ric": ric}, uri_path def fake_experiment(): @@ -130,13 +107,11 @@ def fake_experiment(): return default_uri == default_uri_to_check, current_uri == current_uri_to_check, current_uri -def backtest_analysis(pred, rid): +def backtest_analysis(rid): """backtest and analysis Parameters ---------- - pred : pandas.DataFrame - predict scores rid : str the id of the recorder to be used in this function @@ -147,16 +122,54 @@ def backtest_analysis(pred, rid): """ recorder = R.get_recorder(experiment_name="workflow", recorder_id=rid) + + dataset = init_instance_by_config(CSI300_GBDT_TASK["dataset"]) + model = recorder.load_object("trained_model") + + port_analysis_config = { + "executor": { + "class": "SimulatorExecutor", + "module_path": "qlib.backtest.executor", + "kwargs": { + "time_per_step": "day", + "generate_report": True, + }, + }, + "strategy": { + "class": "TopkDropoutStrategy", + "module_path": "qlib.contrib.strategy.model_strategy", + "kwargs": { + "model": model, + "dataset": dataset, + "topk": 50, + "n_drop": 5, + }, + }, + "backtest": { + "start_time": "2017-01-01", + "end_time": "2020-08-01", + "account": 100000000, + "benchmark": CSI300_BENCH, + "exchange_kwargs": { + "freq": "day", + "limit_threshold": 0.095, + "deal_price": "close", + "open_cost": 0.0005, + "close_cost": 0.0015, + "min_cost": 5, + }, + }, + } + # backtest - par = PortAnaRecord(recorder, port_analysis_config) + par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day") par.generate() - analysis_df = par.load(par.get_path("port_analysis.pkl")) + analysis_df = par.load(par.get_path("port_analysis_1day.pkl")) print(analysis_df) return analysis_df class TestAllFlow(TestAutoData): - PRED_SCORE = None REPORT_NORMAL = None POSITIONS = None RID = None @@ -166,18 +179,18 @@ class TestAllFlow(TestAutoData): shutil.rmtree(str(Path(C["exp_manager"]["kwargs"]["uri"].strip("file:")).resolve())) def test_0_train_with_sigana(self): - TestAllFlow.PRED_SCORE, ic_ric, uri_path = train_with_sigana() + ic_ric, uri_path = train_with_sigana() self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") shutil.rmtree(str(Path(uri_path.strip("file:")).resolve())) def test_1_train(self): - TestAllFlow.PRED_SCORE, ic_ric, TestAllFlow.RID = train() + ic_ric, TestAllFlow.RID = train() self.assertGreaterEqual(ic_ric["ic"].all(), 0, "train failed") self.assertGreaterEqual(ic_ric["ric"].all(), 0, "train failed") def test_2_backtest(self): - analyze_df = backtest_analysis(TestAllFlow.PRED_SCORE, TestAllFlow.RID) + analyze_df = backtest_analysis(TestAllFlow.RID) self.assertGreaterEqual( analyze_df.loc(axis=0)["excess_return_with_cost", "annualized_return"].values[0], 0.10,