From 6a94b455032a8c6702db4efe543ab228a7e56c21 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Fri, 15 Oct 2021 13:52:55 +0800 Subject: [PATCH 01/28] Update order_generator.py --- qlib/contrib/strategy/order_generator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index eff938dd7..9c0f684fd 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -124,8 +124,8 @@ class OrderGenWInteract(OrderGenerator): order_list = trade_exchange.generate_order_for_target_amount_position( target_position=target_amount_dict, current_position=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) return TradeDecisionWO(order_list, self) @@ -188,7 +188,7 @@ class OrderGenWOInteract(OrderGenerator): order_list = trade_exchange.generate_order_for_target_amount_position( target_position=amount_dict, current_position=current.get_stock_amount_dict(), - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) return TradeDecisionWO(order_list, self) From 3ab5721448858248e32a0f3c846bb43b194dc2f3 Mon Sep 17 00:00:00 2001 From: you-n-g Date: Fri, 15 Oct 2021 14:28:08 +0800 Subject: [PATCH 02/28] Fix OrderGenerator's return value --- qlib/contrib/strategy/order_generator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index 9c0f684fd..34102d88a 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -127,7 +127,7 @@ class OrderGenWInteract(OrderGenerator): start_time=trade_start_time, end_time=trade_end_time, ) - return TradeDecisionWO(order_list, self) + return order_list class OrderGenWOInteract(OrderGenerator): @@ -163,7 +163,7 @@ class OrderGenWOInteract(OrderGenerator): :param trade_date: :type trade_date: pd.Timestamp - :rtype: list + :rtype: list of generated orders """ risk_total_value = risk_degree * current.calculate_value() @@ -191,4 +191,4 @@ class OrderGenWOInteract(OrderGenerator): start_time=trade_start_time, end_time=trade_end_time, ) - return TradeDecisionWO(order_list, self) + return order_list From 2e49a5f7c0f47e157f3512c65f37f1a62ec060d1 Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 15 Oct 2021 07:04:47 +0000 Subject: [PATCH 03/28] fix order generator --- qlib/contrib/strategy/order_generator.py | 25 +++++++++++++----------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index 34102d88a..ee20be947 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -82,16 +82,17 @@ class OrderGenWInteract(OrderGenerator): """ # calculate current_tradable_value current_amount_dict = current.get_stock_amount_dict() + current_total_value = trade_exchange.calculate_amount_position_value( amount_dict=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, only_tradable=False, ) current_tradable_value = trade_exchange.calculate_amount_position_value( amount_dict=current_amount_dict, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, only_tradable=True, ) # add cash @@ -105,9 +106,7 @@ class OrderGenWInteract(OrderGenerator): # value. Then just sell all the stocks target_amount_dict = copy.deepcopy(current_amount_dict.copy()) for stock_id in list(target_amount_dict.keys()): - if trade_exchange.is_stock_tradable( - stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time - ): + if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time): del target_amount_dict[stock_id] else: # consider cost rate @@ -118,8 +117,8 @@ class OrderGenWInteract(OrderGenerator): target_amount_dict = trade_exchange.generate_amount_position_from_weight_position( weight_position=target_weight_position, cash=current_tradable_value, - trade_start_time=trade_start_time, - trade_end_time=trade_end_time, + start_time=trade_start_time, + end_time=trade_end_time, ) order_list = trade_exchange.generate_order_for_target_amount_position( target_position=target_amount_dict, @@ -172,13 +171,17 @@ class OrderGenWOInteract(OrderGenerator): for stock_id in target_weight_position: # Current rule will ignore the stock that not hold and cannot be traded at predict date if trade_exchange.is_stock_tradable( - stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time + stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time + ) and trade_exchange.is_stock_tradable( + stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time ): amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] - / trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time) + / trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time) ) + # TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price. + # Maybe a close price with forward fill will be a better solution. elif stock_id in current_stock: amount_dict[stock_id] = ( risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id) From df9745f13433a7b0b03a8017c052e414d79c0ddb Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 15 Oct 2021 09:07:03 +0000 Subject: [PATCH 04/28] support empty order --- qlib/contrib/strategy/order_generator.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/qlib/contrib/strategy/order_generator.py b/qlib/contrib/strategy/order_generator.py index ee20be947..5dfef1510 100644 --- a/qlib/contrib/strategy/order_generator.py +++ b/qlib/contrib/strategy/order_generator.py @@ -80,6 +80,9 @@ class OrderGenWInteract(OrderGenerator): :rtype: list """ + if target_weight_position is None: + return [] + # calculate current_tradable_value current_amount_dict = current.get_stock_amount_dict() @@ -164,6 +167,9 @@ class OrderGenWOInteract(OrderGenerator): :rtype: list of generated orders """ + if target_weight_position is None: + return [] + risk_total_value = risk_degree * current.calculate_value() current_stock = current.get_stock_list() From ac08468330c94e7e11be798f8bbad49ea8b73ea1 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 15 Oct 2021 11:21:03 +0000 Subject: [PATCH 05/28] Make static prediction easier --- .../nested_decision_execution/workflow.py | 4 +- examples/workflow_by_code.py | 2 +- qlib/backtest/signal.py | 83 +++++++++++++++++++ qlib/contrib/strategy/__init__.py | 2 +- qlib/contrib/strategy/cost_control.py | 2 +- qlib/contrib/strategy/rule_strategy.py | 2 + .../{model_strategy.py => signal_strategy.py} | 34 +++++--- qlib/strategy/base.py | 41 +-------- tests/test_all_pipeline.py | 2 +- 9 files changed, 115 insertions(+), 57 deletions(-) create mode 100644 qlib/backtest/signal.py rename qlib/contrib/strategy/{model_strategy.py => signal_strategy.py} (91%) diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index ef6906018..72b6067b3 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -151,7 +151,7 @@ class NestedDecisionExecutionWorkflow: self._train_model(model, dataset) strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, @@ -189,7 +189,7 @@ class NestedDecisionExecutionWorkflow: backtest_config["benchmark"] = self.benchmark strategy_config = { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 486e694a7..248f0423f 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -31,7 +31,7 @@ if __name__ == "__main__": }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py new file mode 100644 index 000000000..192f690c4 --- /dev/null +++ b/qlib/backtest/signal.py @@ -0,0 +1,83 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +from typing import Union +from ..model.base import BaseModel +from ..data.dataset import Dataset +from ..data.dataset.utils import convert_index_format +from ..utils.resam import resam_ts_data +import pandas as pd +import abc + + +class Signal(metaclass=abc.ABCMeta): + """ + Some trading strategy make decisions based on other prediction signals + The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset) + + This + """ + + @abc.abstractmethod + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]: + """ + get the signal at the end of the decision step(from `start_time` to `end_time`) + + Returns + ------- + Union[pd.Series, pd.DataFrame, None]: + returns None if no signal in the specific day + """ + ... + + +class SignalWCache(Signal): + """ + Signal With pandas with based Cache + SignalWCache will store the prepared signal as a attribute and give the according signal based on input query + """ + + def __init__(self, signal: Union[pd.Series, pd.DataFrame]): + """ + + Parameters + ---------- + signal : Union[pd.Series, pd.DataFrame] + The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted) + + instrument datetime + SH600000 2008-01-02 0.079704 + 2008-01-03 0.120125 + 2008-01-04 0.878860 + 2008-01-07 0.505539 + 2008-01-08 0.395004 + """ + self.signal_cache = convert_index_format(signal, level="datetime") + + def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]: + # the frequency of the signal may not algin with the decision frequency of strategy + # so resampling from the data is necessary + # the latest signal leverage more recent data and therefore is used in trading. + signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last") + return signal + + +class ModelSignal(SignalWCache): + ... + + def __init__(self, model: BaseModel, dataset: Dataset): + self.model = model + self.dataset = dataset + pred_scores = self.model.predict(dataset) + if isinstance(pred_scores, pd.DataFrame): + pred_scores = pred_scores.iloc[:, 0] + super().__init__(pred_scores) + + def _update_model(self): + """ + When using online data, update model in each bar as the following steps: + - update dataset with online data, the dataset should support online update + - make the latest prediction scores of the new bar + - update the pred score into the latest prediction + """ + # TODO: this method is not included in the framework and could be refactor later + raise NotImplementedError("_update_model is not implemented!") diff --git a/qlib/contrib/strategy/__init__.py b/qlib/contrib/strategy/__init__.py index e308c1a05..adc1679c1 100644 --- a/qlib/contrib/strategy/__init__.py +++ b/qlib/contrib/strategy/__init__.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. -from .model_strategy import ( +from .signal_strategy import ( TopkDropoutStrategy, WeightStrategyBase, ) diff --git a/qlib/contrib/strategy/cost_control.py b/qlib/contrib/strategy/cost_control.py index b45c03ae9..aaebe3543 100644 --- a/qlib/contrib/strategy/cost_control.py +++ b/qlib/contrib/strategy/cost_control.py @@ -6,7 +6,7 @@ This strategy is not well maintained from .order_generator import OrderGenWInteract -from .model_strategy import WeightStrategyBase +from .signal_strategy import WeightStrategyBase import copy diff --git a/qlib/contrib/strategy/rule_strategy.py b/qlib/contrib/strategy/rule_strategy.py index 23fdd2991..dcf4667ff 100644 --- a/qlib/contrib/strategy/rule_strategy.py +++ b/qlib/contrib/strategy/rule_strategy.py @@ -1,3 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. from pathlib import Path import warnings import numpy as np diff --git a/qlib/contrib/strategy/model_strategy.py b/qlib/contrib/strategy/signal_strategy.py similarity index 91% rename from qlib/contrib/strategy/model_strategy.py rename to qlib/contrib/strategy/signal_strategy.py index 1d22153a7..1adfc517e 100644 --- a/qlib/contrib/strategy/model_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -1,27 +1,33 @@ import copy +from qlib.backtest.signal import ModelSignal, Signal, SignalWCache +from typing import Union +from qlib.data.dataset import Dataset +from qlib.model.base import BaseModel from qlib.backtest.position import Position import warnings import numpy as np import pandas as pd from ...utils.resam import resam_ts_data -from ...strategy.base import ModelStrategy +from ...strategy.base import BaseStrategy from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO from .order_generator import OrderGenWInteract -class TopkDropoutStrategy(ModelStrategy): +class TopkDropoutStrategy(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, topk, n_drop, + model: BaseModel = None, + dataset: Dataset = None, + signal: Union[pd.DataFrame, pd.Series] = None, method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -64,7 +70,7 @@ class TopkDropoutStrategy(ModelStrategy): """ super(TopkDropoutStrategy, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) self.topk = topk self.n_drop = n_drop @@ -73,6 +79,8 @@ class TopkDropoutStrategy(ModelStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable + assert signal is not None or dataset is not None and model is not None + self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -87,7 +95,7 @@ class TopkDropoutStrategy(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) if self.only_tradable: @@ -235,15 +243,17 @@ class TopkDropoutStrategy(ModelStrategy): return TradeDecisionWO(sell_order_list + buy_order_list, self) -class WeightStrategyBase(ModelStrategy): +class WeightStrategyBase(BaseStrategy): # TODO: # 1. Supporting leverage the get_range_limit result from the decision # 2. Supporting alter_outer_trade_decision # 3. Supporting checking the availability of trade decision def __init__( self, - model, - dataset, + *, + model: BaseModel = None, + dataset: Dataset = None, + signal: Union[pd.DataFrame, pd.Series] = None, order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, level_infra=None, @@ -260,12 +270,14 @@ class WeightStrategyBase(ModelStrategy): - In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended. """ super(WeightStrategyBase, self).__init__( - model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs + level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs ) if isinstance(order_generator_cls_or_obj, type): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj + assert signal is not None or dataset is not None and model is not None + self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -298,7 +310,7 @@ class WeightStrategyBase(ModelStrategy): trade_step = self.trade_calendar.get_trade_step() trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step) pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1) - pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last") + pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time) if pred_score is None: return TradeDecisionWO([], self) current_temp = copy.deepcopy(self.trade_position) diff --git a/qlib/strategy/base.py b/qlib/strategy/base.py index bd5d3dbd3..860477544 100644 --- a/qlib/strategy/base.py +++ b/qlib/strategy/base.py @@ -17,7 +17,7 @@ from ..utils import init_instance_by_config from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager from ..backtest.decision import BaseTradeDecision -__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"] +__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"] class BaseStrategy: @@ -194,45 +194,6 @@ class BaseStrategy: return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1]) -class ModelStrategy(BaseStrategy): - """Model-based trading strategy, use model to make predictions for trading""" - - def __init__( - self, - model: BaseModel, - dataset: DatasetH, - outer_trade_decision: BaseTradeDecision = None, - level_infra: LevelInfrastructure = None, - common_infra: CommonInfrastructure = None, - **kwargs, - ): - """ - Parameters - ---------- - model : BaseModel - the model used in when making predictions - dataset : DatasetH - provide test data for model - kwargs : dict - arguments that will be passed into `reset` method - """ - super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs) - self.model = model - self.dataset = dataset - self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime") - if isinstance(self.pred_scores, pd.DataFrame): - self.pred_scores = self.pred_scores.iloc[:, 0] - - def _update_model(self): - """ - When using online data, pdate model in each bar as the following steps: - - update dataset with online data, the dataset should support online update - - make the latest prediction scores of the new bar - - update the pred score into the latest prediction - """ - raise NotImplementedError("_update_model is not implemented!") - - class RLStrategy(BaseStrategy): """RL-based strategy""" diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index da68139a8..69de8b129 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -144,7 +144,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): }, "strategy": { "class": "TopkDropoutStrategy", - "module_path": "qlib.contrib.strategy.model_strategy", + "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { "model": model, "dataset": dataset, From 052aad798278ed3779cc04e78f9227a58272f512 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 15 Oct 2021 14:04:21 +0000 Subject: [PATCH 06/28] simplify signal parameter --- docs/component/workflow.rst | 6 +++++ .../ALSTM/workflow_config_alstm_Alpha158.yaml | 5 ++-- .../ALSTM/workflow_config_alstm_Alpha360.yaml | 7 ++--- .../workflow_config_catboost_Alpha158.yaml | 5 ++-- .../workflow_config_catboost_Alpha360.yaml | 5 ++-- ...rkflow_config_doubleensemble_Alpha158.yaml | 5 ++-- ...rkflow_config_doubleensemble_Alpha360.yaml | 7 ++--- .../GATs/workflow_config_gats_Alpha158.yaml | 7 ++--- .../GATs/workflow_config_gats_Alpha360.yaml | 5 ++-- .../GRU/workflow_config_gru_Alpha158.yaml | 5 ++-- .../GRU/workflow_config_gru_Alpha360.yaml | 7 ++--- .../LSTM/workflow_config_lstm_Alpha158.yaml | 5 ++-- .../LSTM/workflow_config_lstm_Alpha360.yaml | 7 ++--- .../workflow_config_lightgbm_Alpha158.yaml | 5 ++-- ...w_config_lightgbm_Alpha158_multi_freq.yaml | 5 +++- .../workflow_config_lightgbm_Alpha360.yaml | 7 ++--- ..._config_lightgbm_configurable_dataset.yaml | 5 ++-- .../workflow_config_lightgbm_multi_freq.yaml | 5 ++-- .../workflow_config_linear_Alpha158.yaml | 5 ++-- .../workflow_config_localformer_Alpha158.yaml | 5 ++-- .../workflow_config_localformer_Alpha360.yaml | 5 ++-- .../MLP/workflow_config_mlp_Alpha158.yaml | 7 ++--- .../MLP/workflow_config_mlp_Alpha360.yaml | 7 ++--- .../SFM/workflow_config_sfm_Alpha360.yaml | 5 ++-- .../TCTS/workflow_config_tcts_Alpha360.yaml | 7 ++--- .../TFT/workflow_config_tft_Alpha158.yaml | 5 ++-- .../TRA/workflow_config_tra_Alpha158.yaml | 5 ++-- .../workflow_config_tra_Alpha158_full.yaml | 5 ++-- .../TRA/workflow_config_tra_Alpha360.yaml | 5 ++-- .../workflow_config_TabNet_Alpha158.yaml | 5 ++-- .../workflow_config_TabNet_Alpha360.yaml | 5 ++-- .../workflow_config_transformer_Alpha158.yaml | 5 ++-- .../workflow_config_transformer_Alpha360.yaml | 5 ++-- .../workflow_config_xgboost_Alpha158.yaml | 5 ++-- .../workflow_config_xgboost_Alpha360.yaml | 5 ++-- .../nested_decision_execution/workflow.py | 6 ++--- examples/workflow_by_code.py | 3 +-- qlib/backtest/signal.py | 25 +++++++++++++++-- qlib/contrib/strategy/signal_strategy.py | 27 ++++++++++--------- tests/test_all_pipeline.py | 3 +-- 40 files changed, 160 insertions(+), 98 deletions(-) diff --git a/docs/component/workflow.rst b/docs/component/workflow.rst index 84522af99..1b15212ac 100644 --- a/docs/component/workflow.rst +++ b/docs/component/workflow.rst @@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``. kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: limit_threshold: 0.095 account: 100000000 @@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: limit_threshold: 0.095 account: 100000000 diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml index 039040d8f..a8e89e360 100755 --- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml +++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml index 88c6fcd07..3aa8147fc 100644 --- a/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml +++ b/examples/benchmarks/ALSTM/workflow_config_alstm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -86,4 +87,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml index 18e19bd0f..2eb642741 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml index a6cdd1882..982963eea 100644 --- a/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml +++ b/examples/benchmarks/CatBoost/workflow_config_catboost_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml index fb8cce74d..12da23171 100644 --- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml +++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml index d1fbd7807..d9481f12d 100644 --- a/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml +++ b/examples/benchmarks/DoubleEnsemble/workflow_config_doubleensemble_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -100,4 +101,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml index 5387adc24..e056bc845 100644 --- a/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml +++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha158.yaml @@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -94,4 +95,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml index 1ffd6780e..2effecd61 100644 --- a/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml +++ b/examples/benchmarks/GATs/workflow_config_gats_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml index 82c690889..7c525c12a 100755 --- a/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml +++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml index 02c81c850..2daaa0136 100644 --- a/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml +++ b/examples/benchmarks/GRU/workflow_config_gru_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml index f4412c262..bf3738bc0 100755 --- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml +++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml index 10a1dc5df..d550cacb2 100644 --- a/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml +++ b/examples/benchmarks/LSTM/workflow_config_lstm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml index 2bb21d41d..e1171a85d 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml index 46b5c0f80..3d0a7859c 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158_multi_freq.yaml @@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config kwargs: topk: 50 n_drop: 5 + signal: + - + - backtest: verbose: False limit_threshold: 0.095 @@ -80,4 +83,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml index b8af19ec1..053c5bd29 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -76,4 +77,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml index a92f342a1..f1ffc45da 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_configurable_dataset.yaml @@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml index 89fbcb153..20cf7de6e 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_multi_freq.yaml @@ -31,8 +31,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml index 9f055a62c..c4e4d8e21 100644 --- a/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml +++ b/examples/benchmarks/Linear/workflow_config_linear_Alpha158.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml index cd31ecd1e..7f5a78e74 100644 --- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml +++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml index f9cc091fd..9de80a350 100644 --- a/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml +++ b/examples/benchmarks/Localformer/workflow_config_localformer_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml index 8303f3945..b0f95e696 100644 --- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml +++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha158.yaml @@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -98,4 +99,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml index f52c5930d..053dd455a 100644 --- a/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml +++ b/examples/benchmarks/MLP/workflow_config_mlp_Alpha360.yaml @@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -85,4 +86,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml index 5c66400bb..d750a9980 100644 --- a/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml +++ b/examples/benchmarks/SFM/workflow_config_sfm_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml index 484ed45b1..9e0e735d1 100644 --- a/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml +++ b/examples/benchmarks/TCTS/workflow_config_tcts_Alpha360.yaml @@ -30,8 +30,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: @@ -94,4 +95,4 @@ task: - class: PortAnaRecord module_path: qlib.workflow.record_temp kwargs: - config: *port_analysis_config \ No newline at end of file + config: *port_analysis_config diff --git a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml index 0508ce676..d83878e3e 100644 --- a/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml +++ b/examples/benchmarks/TFT/workflow_config_tft_Alpha158.yaml @@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml index f273f62ee..c86f87fc6 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158.yaml @@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml index 8dc82cb99..75f18f3ee 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha158_full.yaml @@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml index bd5b132ee..9ab5b904b 100644 --- a/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml +++ b/examples/benchmarks/TRA/workflow_config_tra_Alpha360.yaml @@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml index 1d1c7da1c..d9b94e86c 100644 --- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml +++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha158.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml index 3d11efe60..830943d6b 100644 --- a/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml +++ b/examples/benchmarks/TabNet/workflow_config_TabNet_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml index 6174abf2e..e36d44c43 100644 --- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml +++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha158.yaml @@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml index 883c18cdc..cab46a4d4 100644 --- a/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml +++ b/examples/benchmarks/Transformer/workflow_config_transformer_Alpha360.yaml @@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml index 502a5e73c..5ee38cf70 100644 --- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml +++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha158.yaml @@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml index a2e40eefb..7c98bd40c 100644 --- a/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml +++ b/examples/benchmarks/XGBoost/workflow_config_xgboost_Alpha360.yaml @@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - model: - dataset: + signal: + - + - topk: 50 n_drop: 5 backtest: diff --git a/examples/nested_decision_execution/workflow.py b/examples/nested_decision_execution/workflow.py index 72b6067b3..d7f5fc813 100644 --- a/examples/nested_decision_execution/workflow.py +++ b/examples/nested_decision_execution/workflow.py @@ -153,8 +153,7 @@ class NestedDecisionExecutionWorkflow: "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, @@ -191,8 +190,7 @@ class NestedDecisionExecutionWorkflow: "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, diff --git a/examples/workflow_by_code.py b/examples/workflow_by_code.py index 248f0423f..7fd299338 100644 --- a/examples/workflow_by_code.py +++ b/examples/workflow_by_code.py @@ -33,8 +33,7 @@ if __name__ == "__main__": "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, diff --git a/qlib/backtest/signal.py b/qlib/backtest/signal.py index 0a56ed281..a342a58be 100644 --- a/qlib/backtest/signal.py +++ b/qlib/backtest/signal.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -from typing import Union +from qlib.utils import init_instance_by_config +from typing import Dict, List, Text, Tuple, Union from ..model.base import BaseModel from ..data.dataset import Dataset from ..data.dataset.utils import convert_index_format @@ -14,7 +15,7 @@ class Signal(metaclass=abc.ABCMeta): Some trading strategy make decisions based on other prediction signals The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset) - This + This interface is tries to provide unified interface for those different sources """ @abc.abstractmethod @@ -79,3 +80,23 @@ class ModelSignal(SignalWCache): """ # TODO: this method is not included in the framework and could be refactor later raise NotImplementedError("_update_model is not implemented!") + + +def create_signal_from( + obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] +) -> Signal: + """ + create signal from diverse information + This method will choose the right method to create a signal based on `obj` + Please refer to the code below. + """ + if isinstance(obj, Signal): + return obj + elif isinstance(obj, (tuple, list)): + return ModelSignal(*obj) + elif isinstance(obj, (dict, str)): + return init_instance_by_config(obj) + elif isinstance(obj, (pd.DataFrame, pd.Series)): + return SignalWCache(signal=obj) + else: + raise NotImplementedError(f"This type of signal is not supported") diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index 1adfc517e..b47da9ed7 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -1,6 +1,8 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. import copy -from qlib.backtest.signal import ModelSignal, Signal, SignalWCache -from typing import Union +from qlib.backtest.signal import Signal, create_signal_from +from typing import Dict, List, Text, Tuple, Union from qlib.data.dataset import Dataset from qlib.model.base import BaseModel from qlib.backtest.position import Position @@ -25,9 +27,7 @@ class TopkDropoutStrategy(BaseStrategy): *, topk, n_drop, - model: BaseModel = None, - dataset: Dataset = None, - signal: Union[pd.DataFrame, pd.Series] = None, + signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -45,6 +45,9 @@ class TopkDropoutStrategy(BaseStrategy): the number of stocks in the portfolio. n_drop : int number of stocks to be replaced in each trading date. + signal : + the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from` + the decision of the strategy will base on the given signal method_sell : str dropout method_sell, random/bottom. method_buy : str @@ -79,8 +82,7 @@ class TopkDropoutStrategy(BaseStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable - assert signal is not None or dataset is not None and model is not None - self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) + self.signal: Signal = create_signal_from(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree @@ -251,9 +253,7 @@ class WeightStrategyBase(BaseStrategy): def __init__( self, *, - model: BaseModel = None, - dataset: Dataset = None, - signal: Union[pd.DataFrame, pd.Series] = None, + signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], order_generator_cls_or_obj=OrderGenWInteract, trade_exchange=None, level_infra=None, @@ -261,6 +261,9 @@ class WeightStrategyBase(BaseStrategy): **kwargs, ): """ + signal : + the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from` + the decision of the strategy will base on the given signal trade_exchange : Exchange exchange that provides market info, used to deal order and generate report - If `trade_exchange` is None, self.trade_exchange will be set with common_infra @@ -276,8 +279,8 @@ class WeightStrategyBase(BaseStrategy): self.order_generator = order_generator_cls_or_obj() else: self.order_generator = order_generator_cls_or_obj - assert signal is not None or dataset is not None and model is not None - self.signal: Signal = ModelSignal(model=model, dataset=dataset) if signal is None else SignalWCache(signal) + + self.signal: Signal = create_signal_from(signal) def get_risk_degree(self, trade_step=None): """get_risk_degree diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 69de8b129..24c6765aa 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -146,8 +146,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): "class": "TopkDropoutStrategy", "module_path": "qlib.contrib.strategy.signal_strategy", "kwargs": { - "model": model, - "dataset": dataset, + "signal": (model, dataset), "topk": 50, "n_drop": 5, }, From 4efb0a75c15909f245fcfa053bf1f3fbc8ced83b Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 15 Oct 2021 15:06:06 +0000 Subject: [PATCH 07/28] Being compatible with previous Qlib version --- .../LightGBM/workflow_config_lightgbm_Alpha158.yaml | 5 ++--- qlib/contrib/strategy/signal_strategy.py | 10 +++++++++- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml index e1171a85d..2d441dea9 100644 --- a/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml +++ b/examples/benchmarks/LightGBM/workflow_config_lightgbm_Alpha158.yaml @@ -14,9 +14,8 @@ port_analysis_config: &port_analysis_config class: TopkDropoutStrategy module_path: qlib.contrib.strategy kwargs: - signal: - - - - + model: + dataset: topk: 50 n_drop: 5 backtest: diff --git a/qlib/contrib/strategy/signal_strategy.py b/qlib/contrib/strategy/signal_strategy.py index b47da9ed7..ae69b4bb6 100644 --- a/qlib/contrib/strategy/signal_strategy.py +++ b/qlib/contrib/strategy/signal_strategy.py @@ -27,7 +27,7 @@ class TopkDropoutStrategy(BaseStrategy): *, topk, n_drop, - signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame], + signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None, method_sell="bottom", method_buy="top", risk_degree=0.95, @@ -36,6 +36,8 @@ class TopkDropoutStrategy(BaseStrategy): trade_exchange=None, level_infra=None, common_infra=None, + model=None, + dataset=None, **kwargs, ): """ @@ -82,6 +84,12 @@ class TopkDropoutStrategy(BaseStrategy): self.risk_degree = risk_degree self.hold_thresh = hold_thresh self.only_tradable = only_tradable + + # This is trying to be compatible with previous version of qlib task config + if model is not None and dataset is not None: + warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning) + signal = model, dataset + self.signal: Signal = create_signal_from(signal) def get_risk_degree(self, trade_step=None): From 22ff8fdc4467ed743e11532ebf0dc95786ffd067 Mon Sep 17 00:00:00 2001 From: Young Date: Sat, 16 Oct 2021 17:09:28 +0000 Subject: [PATCH 08/28] simple change log --- CHANGES.rst | 17 ++++++++++++++++- examples/benchmarks/README.md | 4 ++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/CHANGES.rst b/CHANGES.rst index 114d577f3..3daa1e8e6 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -159,6 +159,21 @@ Version 0.5.0 - Add baselines - public data crawler -Version greater than Version 0.5.0 + +Version 0.8.0 +-------------------- +- The backtest is greatly refactored. + - Nested decision execution framework is supported + - There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed + - The trading limitation is more accurate; + - In `previous version `_, longing and shorting actions share the same action. + - In `current verison `_, the trading limitation is different between loging and shorting action. + - The constant is different when calculating annualized metrics. + - `Current version `_ uses more accurate constant than `previous version `_ + - `A new version `_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again. + - Users could chec kout the backtesting results between `Current version `_ and `previous version `_ + + +Other Versions ---------------------------------- Please refer to `Github release Notes `_ diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index b1b1be82a..5e0d2a61b 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -8,6 +8,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of > > In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ --> +> NOTE: +> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference. + + ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | From c427c64845ef6469ec74db041f34576efd5a6afa Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Tue, 19 Oct 2021 06:17:53 +0000 Subject: [PATCH 09/28] fix calendar --- qlib/backtest/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/backtest/utils.py b/qlib/backtest/utils.py index 51130712d..5db7658b0 100644 --- a/qlib/backtest/utils.py +++ b/qlib/backtest/utils.py @@ -70,7 +70,7 @@ class TradeCalendarManager: - If self.trade_step >= self.self.trade_len, it means the trading is finished - If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step """ - return self.trade_step >= self.trade_len + return self.trade_step >= self.trade_len - 1 def step(self): if self.finished(): From f537222ce395c5baa0f1c962b798c3c4a5202739 Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 21 Oct 2021 11:58:54 +0000 Subject: [PATCH 10/28] make handler seperable --- qlib/data/dataset/handler.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 507e5ea81..47fda4686 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -82,8 +82,6 @@ class DataHandler(Serializable): fetch_orig : bool Return the original data instead of copy if possible. """ - # Set logger - self.logger = get_module_logger("DataHandler") # Setup data loader assert data_loader is not None # to make start_time end_time could have None default value @@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler): DK_R = "raw" DK_I = "infer" DK_L = "learn" + ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"} # process type PTYPE_I = "independent" @@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler): raise AttributeError( "DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data" ) - df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key]) + df = getattr(self, self.ATTR_MAP[data_key]) return df def fetch( @@ -624,3 +623,27 @@ class DataHandlerLP(DataHandler): df = self._get_df_by_key(data_key).head() df = fetch_df_by_col(df, col_set) return df.columns.to_list() + + @classmethod + def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP": + """ + Motivation + - A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic. + - This class make it possible by casting the class to DataHandlerLP and only keep the processed data + + Parameters + ---------- + handler : DataHandlerLP + A subclass of DataHandlerLP + + Returns + ------- + DataHandlerLP: + the converted processed data + """ + new_hd: DataHandlerLP = object.__new__(DataHandlerLP) + new_hd.from_cast = True # add a mark for the casted instance + + for key in list(DataHandlerLP.ATTR_MAP.values()) + ["instruments", "start_time", "end_time", "fetch_orig"]: + setattr(new_hd, key, getattr(handler, key, None)) + return new_hd From a58bc03a8e25e77011e0fde8b4799dc047b5d88d Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 21 Oct 2021 13:15:02 +0000 Subject: [PATCH 11/28] add sepdf(make mini project only rely on qlib) --- qlib/contrib/data/utils/__init__.py | 0 qlib/contrib/data/utils/sepdf.py | 166 ++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+) create mode 100644 qlib/contrib/data/utils/__init__.py create mode 100644 qlib/contrib/data/utils/sepdf.py diff --git a/qlib/contrib/data/utils/__init__.py b/qlib/contrib/data/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py new file mode 100644 index 000000000..b8b4dacda --- /dev/null +++ b/qlib/contrib/data/utils/sepdf.py @@ -0,0 +1,166 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +import pandas as pd +from typing import Dict, Iterable + + +def align_index(df_dict, join): + res = {} + for k, df in df_dict.items(): + if join is not None and k != join: + df = df.reindex(df_dict[join].index) + res[k] = df + return res + + +# Mocking the pd.DataFrame class +class SepDataFrame: + """ + (Sep)erate DataFrame + We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter). + However, they are usally be used seperately at last. + This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive) + + SepDataFrame tries to act like a DataFrame whose column with multiindex + """ + + def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False): + """ + initialize the data based on the dataframe dictionary + + Parameters + ---------- + df_dict : Dict[str, pd.DataFrame] + dataframe dictionary + join : str + how to join the data + It will reindex the dataframe based on the join key. + If join is None, the reindex step will be skipped + + skip_align : + for some cases, we can improve performance by skipping aligning index + """ + self.join = join + + if skip_align: + self._df_dict = df_dict + else: + self._df_dict = align_index(df_dict, join) + + @property + def loc(self): + return SDFLoc(self, join=self.join) + + @property + def index(self): + return self._df_dict[self.join].index + + def apply_each(self, method: str, skip_align=True, *args, **kwargs): + """ + Assumptions: + - inplace methods will return None + """ + inplace = False + df_dict = {} + for k, df in self._df_dict.items(): + df_dict[k] = getattr(df, method)(*args, **kwargs) + if df_dict[k] is None: + inplace = True + if not inplace: + return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align) + + def sort_index(self, *args, **kwargs): + return self.apply_each("sort_index", True, *args, **kwargs) + + def copy(self, *args, **kwargs): + return self.apply_each("copy", True, *args, **kwargs) + + def __getitem__(self, item): + return self._df_dict[item] + + def __setitem__(self, item: str, df: pd.DataFrame): + # TODO: consider the join behavior + self._df_dict[item] = df + + def __contains__(self, item): + return item in self._df_dict + + def droplevel(self, *args, **kwargs): + raise NotImplementedError(f"Please implement the `droplevel` method") + + @property + def columns(self): + dfs = [] + for k, df in self._df_dict.items(): + df = df.head(0) + df.columns = pd.MultiIndex.from_product([[k], df.columns]) + dfs.append(df) + return pd.concat(dfs, axis=1).columns + + # Useless methods + @staticmethod + def merge(df_dict: Dict[str, pd.DataFrame], join: str): + all_df = df_dict[join] + for k, df in df_dict.items(): + if k != join: + all_df = all_df.join(df) + return all_df + + +class SDFLoc: + """Mock Class""" + + def __init__(self, sdf: SepDataFrame, join): + self._sdf = sdf + self.axis = None + self.join = join + + def __call__(self, axis): + self.axis = axis + return self + + def __getitem__(self, args): + if self.axis == 1: + if isinstance(args, str): + return self._sdf[args] + elif isinstance(args, (tuple, list)): + return SepDataFrame({k: self._sdf[k] for k in args}, join=self.join) + else: + raise NotImplementedError(f"This type of input is not supported") + elif self.axis == 0: + return SepDataFrame({k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join) + else: + ax0, *ax1 = args + if len(ax1) == 0: + ax1 = None + df = self._sdf + if ax1 is not None: + df = df.loc(axis=1)[ax1] + if ax0 is not None: + df = df.loc(axis=0)[ax0] + return df + + +# Patch pandas DataFrame +# Tricking isinstance to accept SepDataFrame as its subclass +import builtins + + +def _isinstance(instance, cls): + if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602 + if isinstance(cls, Iterable): + for c in cls: + if c is pd.DataFrame: + return True + elif cls is pd.DataFrame: + return True + return isinstance_orig(instance, cls) # pylint: disable=E0602 + + +builtins.isinstance_orig = builtins.isinstance +builtins.isinstance = _isinstance + +if __name__ == "__main__": + sdf = SepDataFrame({}, join=None) + print(isinstance(sdf, (pd.DataFrame,))) + print(isinstance(sdf, pd.DataFrame)) From 64130d9407c38d1450e3ad72ebc2e033092b79f6 Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 22 Oct 2021 15:20:45 +0800 Subject: [PATCH 12/28] Fix the aggregation function of IndexData --- qlib/backtest/high_performance_ds.py | 5 +++++ qlib/utils/index_data.py | 12 +++++++++--- tests/misc/test_index_data.py | 14 +++++++++++++- 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/qlib/backtest/high_performance_ds.py b/qlib/backtest/high_performance_ds.py index 235bd054b..51847cac3 100644 --- a/qlib/backtest/high_performance_ds.py +++ b/qlib/backtest/high_performance_ds.py @@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote): if is_single_value(start_time, end_time, self.freq, self.region): # this is a very special case. # skip aggregating function to speed-up the query calculation + + # FIXME: + # it will go to the else logic when it comes to the + # 1) the day before holiday when daily trading + # 2) the last minute of the day when intraday trading try: return self.data[stock_id].loc[start_time, field] except KeyError: diff --git a/qlib/utils/index_data.py b/qlib/utils/index_data.py index 5e3942db5..06fb42a5e 100644 --- a/qlib/utils/index_data.py +++ b/qlib/utils/index_data.py @@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator): def columns(self): return self.indices[1] + def __getitem__(self, args): + # NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean + return self.iloc[args] + def _align_indices(self, other: "IndexData") -> "IndexData": """ Align all indices of `other` to `self` before performing the arithmetic operations. @@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator): Parameters ---------- other : "IndexData" - the index in `other` is to be chagned + the index in `other` is to be changed Returns ------- @@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator): """ return len(self.data) - def sum(self, axis=None): + def sum(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nansum(self.data) @@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator): else: raise ValueError(f"axis must be None, 0 or 1") - def mean(self, axis=None): + def mean(self, axis=None, dtype=None, out=None): + assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function" # FIXME: weird logic and not general if axis is None: return np.nanmean(self.data) diff --git a/tests/misc/test_index_data.py b/tests/misc/test_index_data.py index 3cd819a0f..20cda69ff 100644 --- a/tests/misc/test_index_data.py +++ b/tests/misc/test_index_data.py @@ -1,6 +1,5 @@ import numpy as np import pandas as pd - import qlib.utils.index_data as idd import unittest @@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase): # sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) # 2 * sd2 + def test_squeeze(self): + sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"]) + # automatically squeezing + self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData)) + self.assertTrue(not isinstance(sd1.sum(), idd.IndexData)) + self.assertEqual(np.nansum(sd1), 10) + self.assertEqual(np.sum(sd1), 10) + self.assertEqual(sd1.sum(), 10) + self.assertEqual(np.nanmean(sd1), 2.5) + self.assertEqual(np.mean(sd1), 2.5) + self.assertEqual(sd1.mean(), 2.5) + if __name__ == "__main__": unittest.main() From 96b422a9066160c22ebe848638a699a1aa456b2b Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 22 Oct 2021 08:44:47 +0000 Subject: [PATCH 13/28] support market impact cost --- qlib/backtest/exchange.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 9e40e1877..6d5e12a2d 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -34,6 +34,7 @@ class Exchange: open_cost=0.0015, close_cost=0.0025, min_cost=5, + impact_cost=0.0, extra_quote=None, quote_cls=NumpyQuote, **kwargs, @@ -95,6 +96,7 @@ class Exchange: **NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must distinguish `not set` and `disable trade_unit` :param min_cost: min cost, default 5 + :param impact_cost: market impact cost rate (a.k.a. slippage) :param extra_quote: pandas, dataframe consists of columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy']. The limit indicates that the etf is tradable on a specific day. @@ -164,9 +166,12 @@ class Exchange: all_fields = list(all_fields | set(subscribe_fields)) self.all_fields = all_fields + self.open_cost = open_cost self.close_cost = close_cost self.min_cost = min_cost + self.impact_cost = impact_cost + self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold self.volume_threshold = volume_threshold self.extra_quote = extra_quote @@ -718,6 +723,7 @@ class Exchange: :return: trade_price, trade_val, trade_cost """ trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction) + total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time) order.deal_amount = order.amount # set to full amount and clip it step by step # Clipping amount first @@ -773,6 +779,7 @@ class Exchange: raise NotImplementedError("order type {} error".format(order.type)) trade_val = order.deal_amount * trade_price + cost_ratio += self.impact_cost * (trade_val / total_trade_val) ** 2 trade_cost = max(trade_val * cost_ratio, self.min_cost) if trade_val <= 1e-5: # if dealing is not successful, the trade_cost should be zero. From b70caff5226c2e962dfa691a0c9afa63b8887faa Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 22 Oct 2021 08:49:20 +0000 Subject: [PATCH 14/28] add doc --- qlib/backtest/exchange.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 6d5e12a2d..67a11651d 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -96,7 +96,7 @@ class Exchange: **NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must distinguish `not set` and `disable trade_unit` :param min_cost: min cost, default 5 - :param impact_cost: market impact cost rate (a.k.a. slippage) + :param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1. :param extra_quote: pandas, dataframe consists of columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy']. The limit indicates that the etf is tradable on a specific day. From 7313b4dad01b548686a3a8f53f999476978d7238 Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 22 Oct 2021 08:58:37 +0000 Subject: [PATCH 15/28] fix impact cost --- qlib/backtest/exchange.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 67a11651d..41abce226 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -732,8 +732,12 @@ class Exchange: # - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit. self._clip_amount_by_volume(order, dealt_order_amount) + # TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps + trade_val = order.deal_amount * trade_price + adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2 + if order.direction == Order.SELL: - cost_ratio = self.close_cost + cost_ratio = self.close_cost + adj_cost_ratio # sell # if we don't know current position, we choose to sell all # Otherwise, we clip the amount based on current position @@ -756,7 +760,7 @@ class Exchange: self.logger.debug(f"Order clipped due to cash limitation: {order}") elif order.direction == Order.BUY: - cost_ratio = self.open_cost + cost_ratio = self.open_cost + adj_cost_ratio # buy if position is not None: cash = position.get_cash() @@ -778,8 +782,6 @@ class Exchange: else: raise NotImplementedError("order type {} error".format(order.type)) - trade_val = order.deal_amount * trade_price - cost_ratio += self.impact_cost * (trade_val / total_trade_val) ** 2 trade_cost = max(trade_val * cost_ratio, self.min_cost) if trade_val <= 1e-5: # if dealing is not successful, the trade_cost should be zero. From 3d7ebd1fe09a15d22cae4c81b936af8b00f372c1 Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Fri, 22 Oct 2021 10:13:15 +0000 Subject: [PATCH 16/28] add back trade_val --- qlib/backtest/exchange.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index 41abce226..e2707ad39 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -782,6 +782,7 @@ class Exchange: else: raise NotImplementedError("order type {} error".format(order.type)) + trade_val = order.deal_amount * trade_price trade_cost = max(trade_val * cost_ratio, self.min_cost) if trade_val <= 1e-5: # if dealing is not successful, the trade_cost should be zero. From c6bb11fe560d72d2aa33775c7f6c95faa058e14b Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Mon, 25 Oct 2021 05:46:12 +0000 Subject: [PATCH 17/28] avoid trade without enough cash --- qlib/backtest/exchange.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/qlib/backtest/exchange.py b/qlib/backtest/exchange.py index e2707ad39..cc88528fd 100644 --- a/qlib/backtest/exchange.py +++ b/qlib/backtest/exchange.py @@ -690,12 +690,14 @@ class Exchange: f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}" ) - def _get_buy_amount_by_cash_limit(self, trade_price, cash): + def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio): """return the real order amount after cash limit for buying. Parameters ---------- trade_price : float position : cash + cost_ratio : float + Return ---------- float @@ -704,10 +706,10 @@ class Exchange: max_trade_amount = 0 if cash >= self.min_cost: # critical_price means the stock transaction price when the service fee is equal to min_cost. - critical_price = self.min_cost / self.open_cost + self.min_cost + critical_price = self.min_cost / cost_ratio + self.min_cost if cash >= critical_price: - # the service fee is equal to open_cost * trade_amount - max_trade_amount = cash / (1 + self.open_cost) / trade_price + # the service fee is equal to cost_ratio * trade_amount + max_trade_amount = cash / (1 + cost_ratio) / trade_price else: # the service fee is equal to min_cost max_trade_amount = (cash - self.min_cost) / trade_price @@ -765,9 +767,13 @@ class Exchange: if position is not None: cash = position.get_cash() trade_val = order.deal_amount * trade_price - if cash < trade_val + max(trade_val * cost_ratio, self.min_cost): + if cash < max(trade_val * cost_ratio, self.min_cost): + # cash cannot cover cost + order.deal_amount = 0 + self.logger.debug(f"Order clipped due to cost higher than cash: {order}") + elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost): # The money is not enough - max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash) + max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio) order.deal_amount = self.round_amount_by_trade_unit( min(max_buy_amount, order.deal_amount), order.factor ) From 5fa56703ae07d4fd797e592b81cf5e4896e12d94 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 26 Oct 2021 23:32:33 +0800 Subject: [PATCH 18/28] add handler pickle attr, enhance init_instance_by_config --- qlib/contrib/data/utils/sepdf.py | 3 ++- qlib/data/dataset/handler.py | 8 +++++++- qlib/utils/__init__.py | 12 +++++++++--- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py index b8b4dacda..9650b7729 100644 --- a/qlib/contrib/data/utils/sepdf.py +++ b/qlib/contrib/data/utils/sepdf.py @@ -124,7 +124,8 @@ class SDFLoc: if isinstance(args, str): return self._sdf[args] elif isinstance(args, (tuple, list)): - return SepDataFrame({k: self._sdf[k] for k in args}, join=self.join) + new_df_dict = {k: self._sdf[k] for k in args} + return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0]) else: raise NotImplementedError(f"This type of input is not supported") elif self.axis == 0: diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index 47fda4686..134091c22 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -644,6 +644,12 @@ class DataHandlerLP(DataHandler): new_hd: DataHandlerLP = object.__new__(DataHandlerLP) new_hd.from_cast = True # add a mark for the casted instance - for key in list(DataHandlerLP.ATTR_MAP.values()) + ["instruments", "start_time", "end_time", "fetch_orig"]: + for key in list(DataHandlerLP.ATTR_MAP.values()) + [ + "instruments", + "start_time", + "end_time", + "fetch_orig", + "drop_raw", + ]: setattr(new_hd, key, getattr(handler, key, None)) return new_hd diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index f6a6632ea..12553411c 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod ---------- config : [dict, str] similar to config + please refer to the doc of init_instance_by_config default_module : Python module or str It should be a python module to load the class type @@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod _callable = config["class"] # the class type itself is passed in kwargs = config.get("kwargs", {}) elif isinstance(config, str): - module = get_module_by_module_path(default_module) + # a.b.c.ClassName + *m_path, cls = config.split(".") + m_path = ".".join(m_path) + module = get_module_by_module_path(default_module if m_path == "" else m_path) - _callable = getattr(module, config) + _callable = getattr(module, cls) kwargs = {} else: raise NotImplementedError(f"This type of input is not supported") @@ -260,7 +264,9 @@ def init_instance_by_config( 1) specify a pickle object - path like 'file:////obj.pkl' 2) specify a class name - - "ClassName": getattr(module, config)() will be used. + - "ClassName": getattr(module, "ClassName")() will be used. + 3) specify module path with class name + - "a.b.c.ClassName" getattr(, "ClassName")() will be used. object example: instance of accept_types default_module : Python module From 31e9d529de3b41d19f89bd59e9171561f29a2a9c Mon Sep 17 00:00:00 2001 From: Young Date: Thu, 28 Oct 2021 00:01:19 +0800 Subject: [PATCH 19/28] Add multi horizon task generator --- qlib/__init__.py | 13 +++++---- qlib/workflow/task/gen.py | 56 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 5 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index 107819860..3989b3692 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -152,8 +152,11 @@ def init_from_yaml_conf(conf_path, **kwargs): :param conf_path: A path to the qlib config in yml format """ - with open(conf_path) as f: - config = yaml.safe_load(f) + if conf_path is None: + config = {} + else: + with open(conf_path) as f: + config = yaml.safe_load(f) config.update(kwargs) default_conf = config.pop("default_conf", "client") init(default_conf, **config) @@ -216,7 +219,7 @@ def auto_init(**kwargs): .. code-block:: yaml conf_type: ref - qlib_cfg: '' + qlib_cfg: '' # this could be null reference no config from other files # following configs in `qlib_cfg_update` is project=specific qlib_cfg_update: exp_manager: @@ -259,8 +262,8 @@ def auto_init(**kwargs): # - There is a shared configure file and you don't want to edit it inplace. # - The shared configure may be updated later and you don't want to copy it. # - You have some customized config. - qlib_conf_path = conf["qlib_cfg"] - qlib_conf_update = conf.get("qlib_cfg_update") + qlib_conf_path = conf.get("qlib_cfg", None) + qlib_conf_update = conf.get("qlib_cfg_update", {}) init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs) logger = get_module_logger("Initialization") logger.info(f"Auto load project config: {conf_pp}") diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 2fc87b1a4..2aab07b4b 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -112,6 +112,9 @@ def handler_mod(task: dict, rolling_gen): except KeyError: # Maybe dataset do not have handler, then do nothing. pass + except TypeError: + # May be the handler is a string. `"handler.pkl"["kwargs"]` will raise TypeError + pass class RollingGen(TaskGen): @@ -259,3 +262,56 @@ class RollingGen(TaskGen): # Update the following rolling res.extend(self.gen_following_tasks(t, test_end)) return res + + +class MultiHorizonGenBase(TaskGen): + def __init__(self, horizon: List[int] = [5], label_leak_n=2): + """ + This task generator tries to genrate tasks for different horizons based on an existing task + + Parameters + ---------- + horizon : List[int] + the possible horizons of the tasks + label_leak_n : int + How many future days it will take to get complete label after the day making prediction + For example: + - User make prediction on day `T`(after getting the close price on `T`) + - The label is the return of buying stock on `T + 1` and selling it on `T + 2` + - the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample) + """ + self.horizon = list(horizon) + self.label_leak_n = label_leak_n + self.ta = TimeAdjuster() + self.test_key = "test" + + @abc.abstractmethod + def set_horizon(self, task: dict, hr: int): + """ + This method is designed to change the task **in place** + + Parameters + ---------- + task : dict + Qlib's task + hr : int + the horizon of task + """ + + def generate(self, task: dict): + res = [] + for hr in self.horizon: + + # Add horizon + t = copy.deepcopy(task) + self.set_horizon(t, hr) + + # adjust segment + segments = self.ta.align_seg(t["dataset"]["kwargs"]["segments"]) + test_start = min(t for t in segments[self.test_key] if t is not None) + for k in list(segments.keys()): + if k != self.test_key: + segments[k] = self.ta.truncate(segments[k], test_start, hr + self.label_leak_n) + t["dataset"]["kwargs"]["segments"] = segments + res.append(t) + return res From 82f8ff906695b105eb5e0e82baf7c88cf87505ba Mon Sep 17 00:00:00 2001 From: Young Date: Mon, 1 Nov 2021 00:51:21 +0800 Subject: [PATCH 20/28] Update seperate dataframe --- qlib/contrib/data/utils/sepdf.py | 36 +++++++++++++++++++++++--------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/qlib/contrib/data/utils/sepdf.py b/qlib/contrib/data/utils/sepdf.py index 9650b7729..58664c46c 100644 --- a/qlib/contrib/data/utils/sepdf.py +++ b/qlib/contrib/data/utils/sepdf.py @@ -75,6 +75,10 @@ class SepDataFrame: def copy(self, *args, **kwargs): return self.apply_each("copy", True, *args, **kwargs) + def _update_join(self): + if self.join not in self: + self.join = next(iter(self._df_dict.keys())) + def __getitem__(self, item): return self._df_dict[item] @@ -82,9 +86,16 @@ class SepDataFrame: # TODO: consider the join behavior self._df_dict[item] = df + def __delitem__(self, item: str): + del self._df_dict[item] + self._update_join() + def __contains__(self, item): return item in self._df_dict + def __len__(self): + return len(self._df_dict[self.join]) + def droplevel(self, *args, **kwargs): raise NotImplementedError(f"Please implement the `droplevel` method") @@ -125,21 +136,26 @@ class SDFLoc: return self._sdf[args] elif isinstance(args, (tuple, list)): new_df_dict = {k: self._sdf[k] for k in args} - return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0]) + return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True) else: raise NotImplementedError(f"This type of input is not supported") elif self.axis == 0: - return SepDataFrame({k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join) + return SepDataFrame( + {k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True + ) else: - ax0, *ax1 = args - if len(ax1) == 0: - ax1 = None df = self._sdf - if ax1 is not None: - df = df.loc(axis=1)[ax1] - if ax0 is not None: - df = df.loc(axis=0)[ax0] - return df + if isinstance(args, tuple): + ax0, *ax1 = args + if len(ax1) == 0: + ax1 = None + if ax1 is not None: + df = df.loc(axis=1)[ax1] + if ax0 is not None: + df = df.loc(axis=0)[ax0] + return df + else: + return df.loc(axis=0)[args] # Patch pandas DataFrame From 426b98a3bc92fa0d0e60a610fb417ac7ec5187cd Mon Sep 17 00:00:00 2001 From: Young Date: Sun, 31 Oct 2021 08:25:41 +0000 Subject: [PATCH 21/28] make the logic of online manager cleaner --- qlib/data/dataset/__init__.py | 2 +- qlib/workflow/online/manager.py | 88 +++++++++++++++++++++++++++------ qlib/workflow/task/gen.py | 5 ++ 3 files changed, 80 insertions(+), 15 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 1002df8ba..7ad5f4c6d 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -546,7 +546,7 @@ class TSDatasetH(DatasetH): dtype = kwargs.pop("dtype", None) start, end = slc.start, slc.stop flt_col = kwargs.pop("flt_col", None) - # TSDatasetH will retrieve more data for complete + # TSDatasetH will retrieve more data for complete time-series data = self._prepare_raw_seg(slc, **kwargs) flt_kwargs = deepcopy(kwargs) diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index b4b509483..e9f0fe9d2 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -21,19 +21,65 @@ Situations Description Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It will train models task by task and strategy by strategy. -Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train - nothing until all tasks have been prepared. It makes user can train all tasks in - the end of `routine` or `first_train`. +Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by + different strategies. It makes users can parallelly train all tasks at the end of + `routine` or `first_train`. Otherwise, these functions will get stuck when each + strategy prepare tasks. -Simulation + Trainer When your models have some temporal dependence on the previous models, then you - need to consider using Trainer. This means it will REAL train your models in - every routine and prepare signals for every routine. +Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it + is for simulation/backtesting instead of online trading Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer for the ability to multitasking. It means all tasks in all routines can be REAL trained at the end of simulating. The signals will be prepared well at different time segments (based on whether or not any new model is online). ========================= =================================================================================== + +Here is some pseudo code the demonstrate the workflow of each situation + +For simplicity + - Only one strategy is used in the strategy + - `update_online_pred` is only called in the online mode and is ignored + +1) `Online + Trainer` + +.. code-block:: python + + tasks = first_train() + models = trainer.train(tasks) + trainer.end_train(models) + for day in online_trading_days: + # OnlineManager.routine + models = trainer.train(strategy.prepare_tasks()) # for each strategy + strategy.prepare_online_models(models) # for each strategy + + trainer.end_train(models) + prepare_signals() # prepare trading signals daily + + +`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`. + + +2) `Simulation + DelayTrainer` + +.. code-block:: python + + # simulate + tasks = first_train() + models = trainer.train(tasks) + for day in historical_calendars: + # OnlineManager.routine + models = trainer.train(strategy.prepare_tasks()) # for each strategy + strategy.prepare_online_models(models) # for each strategy + # delay_prepare() + # FIXME: Currently the delay_prepare is not implemented in a proper way. + trainer.end_train() + prepare_signals() + + +# Can we simplify current workflow? +- Can reduce the number of state of tasks? + - For each task, we have three phases (i.e. task, partly trained task, final trained task) """ import logging @@ -58,7 +104,7 @@ class OnlineManager(Serializable): """ STATUS_SIMULATING = "simulating" # when calling `simulate` - STATUS_NORMAL = "normal" # the normal status + STATUS_ONLINE = "online" # the normal status. It is used when online trading def __init__( self, @@ -87,12 +133,24 @@ class OnlineManager(Serializable): self.begin_time = pd.Timestamp(begin_time) self.cur_time = self.begin_time # OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}. + # It records the online servnig models of each strategy for each day. self.history = {} if trainer is None: trainer = TrainerR() self.trainer = trainer self.signals = None - self.status = self.STATUS_NORMAL + self.status = self.STATUS_ONLINE + + def _postpone_action(self): + """ + Should the workflow to postpone the following actions to the end (in delay_prepare) + - trainer.end_train + - prepare_signals + + Postpone these actions is to support simulating/backtest online strategies without time dependencies. + All the actions can be done parallelly at the end. + """ + return self.status == self.STATUS_SIMULATING and self.trainer.is_delay() def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}): """ @@ -113,12 +171,12 @@ class OnlineManager(Serializable): models = self.trainer.train(tasks, experiment_name=strategy.name_id) models_list.append(models) self.logger.info(f"Finished training {len(models)} models.") - # FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the + # FIXME: Train multiple online models at `first_train` will result in getting too much online models at the # start. online_models = strategy.prepare_online_models(models, **model_kwargs) self.history.setdefault(self.cur_time, {})[strategy] = online_models - if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + if not self._postpone_action(): for strategy, models in zip(strategies, models_list): models = self.trainer.end_train(models, experiment_name=strategy.name_id) @@ -160,10 +218,10 @@ class OnlineManager(Serializable): # The online model may changes in the above processes # So updating the predictions of online models should be the last step - if self.status == self.STATUS_NORMAL: + if self.status == self.STATUS_ONLINE: strategy.tool.update_online_pred() - if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay(): + if not self._postpone_action(): for strategy, models in zip(self.strategies, models_list): models = self.trainer.end_train(models, experiment_name=strategy.name_id) self.prepare_signals(**signal_kwargs) @@ -278,13 +336,13 @@ class OnlineManager(Serializable): signal_kwargs=signal_kwargs, ) # delay prepare the models and signals - if self.trainer.is_delay(): + if self._postpone_action(): self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs) # FIXME: get logging level firstly and restore it here set_global_logger_level(logging.DEBUG) self.logger.info(f"Finished preparing signals") - self.status = self.STATUS_NORMAL + self.status = self.STATUS_ONLINE return self.get_signals() def delay_prepare(self, model_kwargs={}, signal_kwargs={}): @@ -295,6 +353,8 @@ class OnlineManager(Serializable): model_kwargs: the params for `end_train` signal_kwargs: the params for `prepare_signals` """ + # FIXME: + # This method is not implemented in the proper way!!! last_models = {} signals_time = D.calendar()[0] need_prepare = False diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 2aab07b4b..45fba12da 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -94,6 +94,11 @@ class TaskGen(metaclass=abc.ABCMeta): def handler_mod(task: dict, rolling_gen): """ Help to modify the handler end time when using RollingGen + It try to handle the following case + - Hander's data end_time is earlier than dataset's test_data's segments. + - To handle this, handler's data's end_time is extended. + + If the handler's end_time is None, then it is not necessary to change it's end time. Args: task (dict): a task template From e54b019ee22a59d2a5d055db45fbb9b32f73b0bd Mon Sep 17 00:00:00 2001 From: Young Date: Mon, 1 Nov 2021 06:22:03 +0000 Subject: [PATCH 22/28] solve init kwargs conflictions --- qlib/__init__.py | 11 +++++++++-- qlib/workflow/task/collect.py | 5 ++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/qlib/__init__.py b/qlib/__init__.py index 3989b3692..19a7e09af 100644 --- a/qlib/__init__.py +++ b/qlib/__init__.py @@ -249,6 +249,7 @@ def auto_init(**kwargs): except FileNotFoundError: init(**kwargs) else: + logger = get_module_logger("Initialization") conf_pp = pp / "config.yaml" with conf_pp.open() as f: conf = yaml.safe_load(f) @@ -263,7 +264,13 @@ def auto_init(**kwargs): # - The shared configure may be updated later and you don't want to copy it. # - You have some customized config. qlib_conf_path = conf.get("qlib_cfg", None) + + # merge the arguments qlib_conf_update = conf.get("qlib_cfg_update", {}) - init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs) - logger = get_module_logger("Initialization") + for k, v in kwargs.items(): + if k in qlib_conf_update: + logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'") + qlib_conf_update.update(kwargs) + + init_from_yaml_conf(qlib_conf_path, **qlib_conf_update) logger.info(f"Auto load project config: {conf_pp}") diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index 467281666..13fcd0202 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,6 +5,7 @@ Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on. """ +from libs.qlib.qlib.log import TimeInspector from typing import Callable, Dict, List from qlib.log import get_module_logger from qlib.utils.serial import Serializable @@ -190,7 +191,9 @@ class RecorderCollector(Collector): collect_dict = {} # filter records - recs = self.experiment.list_recorders(**self.list_kwargs) + + with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"): + recs = self.experiment.list_recorders(**self.list_kwargs) recs_flt = {} for rid, rec in recs.items(): if rec_filter_func is None or rec_filter_func(rec): From d929d4bb210574fb63861c3e6489c3bf117a3548 Mon Sep 17 00:00:00 2001 From: Dong Zhou Date: Mon, 1 Nov 2021 09:29:44 +0000 Subject: [PATCH 23/28] rm recorder temp file --- qlib/workflow/recorder.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 0bf6f4841..73da5e19f 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import os from qlib.utils.serial import Serializable import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle @@ -333,7 +334,9 @@ class MLflowRecorder(Recorder): try: path = self.client.download_artifacts(self.id, name) with Path(path).open("rb") as f: - return pickle.load(f) + data = pickle.load(f) + os.remove(path) + return data except Exception as e: raise LoadObjectError(message=str(e)) From 7a884fa9f250d57f022bf144510e67b0cf92a468 Mon Sep 17 00:00:00 2001 From: Young Date: Mon, 1 Nov 2021 18:55:44 +0800 Subject: [PATCH 24/28] remove redundant file only when remote artifact --- qlib/workflow/recorder.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 73da5e19f..13c4bc7a0 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -11,6 +11,7 @@ from datetime import datetime from qlib.utils.exceptions import LoadObjectError from ..utils.objm import FileManager from ..log import get_module_logger +from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository logger = get_module_logger("workflow", logging.INFO) @@ -335,7 +336,11 @@ class MLflowRecorder(Recorder): path = self.client.download_artifacts(self.id, name) with Path(path).open("rb") as f: data = pickle.load(f) - os.remove(path) + ar = self.client._tracking_client._get_artifact_repo(self.id) + if isinstance(ar, AzureBlobArtifactRepository): + # for saving disk space + # For safety, only remove redundant file for specific ArtifactRepository + shutil.rmtree(Path(path).absolute().parent) return data except Exception as e: raise LoadObjectError(message=str(e)) From 25931857210170fcbf06a7d302ddb40f2a513bdc Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 2 Nov 2021 11:03:23 +0800 Subject: [PATCH 25/28] Simplify TSDataset and async recorder --- qlib/data/dataset/__init__.py | 24 +++++++------- qlib/utils/paral.py | 59 ++++++++++++++++++++++++++++++++++- qlib/workflow/recorder.py | 13 +++++++- 3 files changed, 82 insertions(+), 14 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 7ad5f4c6d..5cc7d3c2d 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -524,20 +524,18 @@ class TSDatasetH(DatasetH): def setup_data(self, **kwargs): super().setup_data(**kwargs) + # make sure the calendar is updated to latest when loading data from new config cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique() - cal = sorted(cal) - self.cal = cal + self.cal = sorted(cal) - def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame: + @staticmethod + def _extend_slice(slc: slice, cal: list, step_len: int) -> slice: # Dataset decide how to slice data(Get more data for timeseries). start, end = slc.start, slc.stop - start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start)) - pad_start_idx = max(0, start_idx - self.step_len) - pad_start = self.cal[pad_start_idx] - - # TSDatasetH will retrieve more data for complete - data = super()._prepare_seg(slice(pad_start, end), **kwargs) - return data + start_idx = bisect.bisect_left(cal, pd.Timestamp(start)) + pad_start_idx = max(0, start_idx - step_len) + pad_start = cal[pad_start_idx] + return slice(pad_start, end) def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler: """ @@ -547,12 +545,14 @@ class TSDatasetH(DatasetH): start, end = slc.start, slc.stop flt_col = kwargs.pop("flt_col", None) # TSDatasetH will retrieve more data for complete time-series - data = self._prepare_raw_seg(slc, **kwargs) + + ext_slice = self._extend_slice(slc, self.cal, self.step_len) + data = super()._prepare_seg(ext_slice, **kwargs) flt_kwargs = deepcopy(kwargs) if flt_col is not None: flt_kwargs["col_set"] = flt_col - flt_data = self._prepare_raw_seg(slc, **flt_kwargs) + flt_data = self._prepare_seg(ext_slice, **flt_kwargs) assert len(flt_data.columns) == 1 else: flt_data = None diff --git a/qlib/utils/paral.py b/qlib/utils/paral.py index 075a1adb8..48b427a28 100644 --- a/qlib/utils/paral.py +++ b/qlib/utils/paral.py @@ -1,9 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import pandas as pd +from functools import partial +from threading import Thread +from typing import Callable + from joblib import Parallel, delayed from joblib._parallel_backends import MultiprocessingBackend +import pandas as pd + +from queue import Queue class ParallelExt(Parallel): @@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru return pd.concat(dfs, axis=axis).sort_index() else: return _naive_group_apply(df) + + +class AsyncCaller: + """ + This AsyncCaller tries to make it easier to async call + + Currently, it is used in MLflowRecorder to make functions like `log_params` async + + NOTE: + - This caller didn't consider the return value + """ + + STOP_MARK = "__STOP" + + def __init__(self) -> None: + self._q = Queue() + self._stop = False + self._t = Thread(target=self.run) + self._t.start() + + def close(self): + self._q.put(self.STOP_MARK) + + def run(self): + while True: + data = self._q.get() + if data == self.STOP_MARK: + break + else: + data() + + def __call__(self, func, *args, **kwargs): + self._q.put(partial(func, *args, **kwargs)) + + def wait(self, close=True): + if close: + self.close() + self._t.join() + + @staticmethod + def async_dec(ac_attr): + def decorator_func(func): + def wrapper(self, *args, **kwargs): + if isinstance(getattr(self, ac_attr, None), Callable): + return getattr(self, ac_attr)(func, self, *args, **kwargs) + else: + return func(self, *args, **kwargs) + + return wrapper + + return decorator_func diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 13c4bc7a0..2fff37eaa 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -9,8 +9,9 @@ from pathlib import Path from datetime import datetime from qlib.utils.exceptions import LoadObjectError +from qlib.utils.paral import AsyncCaller from ..utils.objm import FileManager -from ..log import get_module_logger +from ..log import TimeInspector, get_module_logger from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository logger = get_module_logger("workflow", logging.INFO) @@ -229,6 +230,7 @@ class MLflowRecorder(Recorder): if mlflow_run.info.end_time is not None else None ) + self.async_log = None def __repr__(self): name = self.__class__.__name__ @@ -287,6 +289,10 @@ class MLflowRecorder(Recorder): self.status = Recorder.STATUS_R logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...") + # NOTE: making logging async. + # - This may cause delay when uploading results + # - The logging time may not be accurate + self.async_log = AsyncCaller() return run def end_run(self, status: str = Recorder.STATUS_S): @@ -300,6 +306,8 @@ class MLflowRecorder(Recorder): self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S") if self.status != Recorder.STATUS_S: self.status = status + with TimeInspector.logt("waiting `async_log`"): + self.async_log.wait() def save_objects(self, local_path=None, artifact_path=None, **kwargs): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." @@ -345,14 +353,17 @@ class MLflowRecorder(Recorder): except Exception as e: raise LoadObjectError(message=str(e)) + @AsyncCaller.async_dec(ac_attr="async_log") def log_params(self, **kwargs): for name, data in kwargs.items(): self.client.log_param(self.id, name, data) + @AsyncCaller.async_dec(ac_attr="async_log") def log_metrics(self, step=None, **kwargs): for name, data in kwargs.items(): self.client.log_metric(self.id, name, data, step=step) + @AsyncCaller.async_dec(ac_attr="async_log") def set_tags(self, **kwargs): for name, data in kwargs.items(): self.client.set_tag(self.id, name, data) From 3943b7001fcaa3ac484c0b8e0a6cd7a8dc4a5bb4 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 2 Nov 2021 14:32:09 +0800 Subject: [PATCH 26/28] fix CI bug for AyncCaller --- qlib/workflow/recorder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 2fff37eaa..056d75be1 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -308,6 +308,7 @@ class MLflowRecorder(Recorder): self.status = status with TimeInspector.logt("waiting `async_log`"): self.async_log.wait() + self.async_log = None def save_objects(self, local_path=None, artifact_path=None, **kwargs): assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly." From 4f2d6b0d849d572ec69243ad7d51977864027324 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 2 Nov 2021 20:41:39 +0800 Subject: [PATCH 27/28] fix pytorch memory amount error --- qlib/data/dataset/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 5cc7d3c2d..46b90402d 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -320,6 +320,7 @@ class TSDataSampler: self.flt_data = flt_data.values self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + self.idx_map = self.idx_map2arr(self.idx_map) self.start_idx, self.end_idx = self.data_index.slice_locs( start=time_to_slc_point(start), end=time_to_slc_point(end) @@ -328,6 +329,25 @@ class TSDataSampler: del self.data # save memory + @staticmethod + def idx_map2arr(idx_map): + # pytorch data sampler will have better memory control without large dict or list + # - https://github.com/pytorch/pytorch/issues/13243 + # - https://github.com/airctic/icevision/issues/613 + # So we convert the dict into int array. + # The arr_map is expected to behave the same as idx_map + + dtype = np.int32 + # set a index out of bound to indicate the none existing + no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max) + + max_idx = max(idx_map.keys()) + arr_map = [] + for i in range(max_idx + 1): + arr_map.append(idx_map.get(i, no_existing_idx)) + arr_map = np.array(arr_map, dtype=dtype) + return arr_map + @staticmethod def flt_idx_map(flt_data, idx_map): idx = 0 From 3fa48d7017b94dab2c458131295d0da0a02a362f Mon Sep 17 00:00:00 2001 From: Young Date: Fri, 5 Nov 2021 11:34:21 +0000 Subject: [PATCH 28/28] simplify record tmp --- qlib/contrib/workflow/record_temp.py | 7 +- qlib/workflow/record_temp.py | 139 +++++++++++++++------------ tests/test_all_pipeline.py | 14 +-- 3 files changed, 90 insertions(+), 70 deletions(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index e7c80cf6e..8d10b2ab4 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp): if save: save_name = "results-{:}.pkl".format(key) - self.recorder.save_objects(**{save_name: results}) + self.save(**{save_name: results}) logger.info( "The record '{:}' has been saved as the artifact of the Experiment {:}".format( save_name, self.recorder.experiment_id @@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp): metrics = {"MSE": mse, "RMSE": np.sqrt(mse)} objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)} self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics)) def list(self): - paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")] - return paths + return ["mse.pkl", "rmse.pkl"] diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 0d85311ee..07422243d 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -9,6 +9,9 @@ import pandas as pd from pathlib import Path from pprint import pprint from typing import Union, List +from collections import defaultdict + +from qlib.utils.exceptions import LoadObjectError from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis from ..data.dataset import DatasetH @@ -45,6 +48,16 @@ class RecordTemp: return "/".join(names) + def save(self, **kwargs): + """ + It behaves the same as self.recorder.save_objects. + But it is an easier interface because users don't have to care about `get_path` and `artifact_path` + """ + art_path = self.get_path() + if art_path == "": + art_path = None + self.recorder.save_objects(artifact_path=art_path, **kwargs) + def __init__(self, recorder): self._recorder = recorder @@ -67,31 +80,37 @@ class RecordTemp: """ raise NotImplementedError(f"Please implement the `generate` method.") - def load(self, name): + def load(self, name: str, parents: bool = True): """ - Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API - with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them - in the future:: - - sar = SigAnaRecord(recorder) - ic = sar.load(sar.get_path("ic.pkl")) + It behaves the same as self.recorder.load_object. + But it is an easier interface because users don't have to care about `get_path` and `artifact_path` Parameters ---------- name : str the name for the file to be load. + parents : bool + Each recorder has different `artifact_path`. + So parents recursively find the path in parents + Sub classes has higher priority + Return ------ The stored records. """ - # try to load the saved object - obj = self.recorder.load_object(name) - return obj + try: + return self.recorder.load_object(self.get_path(name)) + except LoadObjectError: + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + return self.load(name, parents=True) def list(self): """ List the supported artifacts. + Users don't have to consider self.get_path Return ------ @@ -99,7 +118,7 @@ class RecordTemp: """ return [] - def check(self, include_self: bool = False): + def check(self, include_self: bool = False, parents: bool = True): """ Check if the records is properly generated and saved. It is useful in following examples @@ -110,19 +129,34 @@ class RecordTemp: ---------- include_self : bool is the file generated by self included + parents : bool + will we check parents Raise ------ - FileExistsError: whether the records are stored properly. + FileNotFoundError + : whether the records are stored properly. """ - artifacts = set(self.recorder.list_artifacts()) if include_self: + + # Some mlflow backend will not list the directly recursively. + # So we force to the directly + artifacts = {} + + def _get_arts(dirn): + if dirn not in artifacts: + artifacts[dirn] = self.recorder.list_artifacts(dirn) + return artifacts[dirn] + for item in self.list(): - if item not in artifacts: - raise FileExistsError(item) - if self.depend_cls is not None: - with class_casting(self, self.depend_cls): - self.check(include_self=True) + ps = self.get_path(item).split("/") + dirn, fn = "/".join(ps[:-1]), ps[-1] + if self.get_path(item) not in _get_arts(dirn): + raise FileNotFoundError + if parents: + if self.depend_cls is not None: + with class_casting(self, self.depend_cls): + self.check(include_self=True) class SignalRecord(RecordTemp): @@ -158,7 +192,7 @@ class SignalRecord(RecordTemp): pred = self.model.predict(self.dataset) if isinstance(pred, pd.Series): pred = pred.to_frame("score") - self.recorder.save_objects(**{"pred.pkl": pred}) + self.save(**{"pred.pkl": pred}) logger.info( f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" @@ -169,15 +203,11 @@ class SignalRecord(RecordTemp): if isinstance(self.dataset, DatasetH): raw_label = self.generate_label(self.dataset) - self.recorder.save_objects(**{"label.pkl": raw_label}) + self.save(**{"label.pkl": raw_label}) - @staticmethod - def list(): + def list(self): return ["pred.pkl", "label.pkl"] - def load(self, name="pred.pkl"): - return super().load(name) - class HFSignalRecord(SignalRecord): """ @@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord): } ) self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [ - self.get_path("ic.pkl"), - self.get_path("ric.pkl"), - self.get_path("long_pre.pkl"), - self.get_path("short_pre.pkl"), - self.get_path("long_short_r.pkl"), - self.get_path("long_avg_r.pkl"), - ] - return paths + return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"] class SigAnaRecord(RecordTemp): @@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp): artifact_path = "sig_analysis" depend_cls = SignalRecord - def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0): + def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False): super().__init__(recorder=recorder) self.ana_long_short = ana_long_short self.ann_scaler = ann_scaler self.label_col = label_col + self.skip_existing = skip_existing def generate(self, **kwargs): + if self.skip_existing: + try: + self.check(include_self=True, parents=False) + except FileNotFoundError: + pass # continue to generating metrics + else: + logger.info("The results has previously generated, generation skipped.") + return + self.check() pred = self.load("pred.pkl") @@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp): } ) self.recorder.log_metrics(**metrics) - self.recorder.save_objects(**objects, artifact_path=self.get_path()) + self.save(**objects) pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")] + paths = ["ic.pkl", "ric.pkl"] if self.ana_long_short: - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + paths.extend(["long_short_r.pkl", "long_avg_r.pkl"]) return paths @@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp): executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config ) for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items(): - self.recorder.save_objects( - **{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path() - ) - self.recorder.save_objects( - **{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"report_normal_{_freq}.pkl": report_normal}) + self.save(**{f"positions_normal_{_freq}.pkl": positions_normal}) for _freq, indicators_normal in indicator_dict.items(): - self.recorder.save_objects( - **{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal}) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq not in portfolio_metric_dict: @@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict()) self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp): analysis_dict = analysis_df["value"].to_dict() self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()}) # save results - self.recorder.save_objects( - **{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path() - ) + self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}) logger.info( f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}" ) @@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp): for _freq in self.all_freq: list_path.extend( [ - PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"), - PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"), + f"report_normal_{_freq}.pkl", + f"positions_normal_{_freq}.pkl", ] ) for _analysis_freq in self.risk_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl")) + list_path.append(f"port_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"risk_analysis freq {_analysis_freq} is not found") for _analysis_freq in self.indicator_analysis_freq: if _analysis_freq in self.all_freq: - list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl")) + list_path.append(f"indicator_analysis_{_analysis_freq}.pkl") else: warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found") - return list_path diff --git a/tests/test_all_pipeline.py b/tests/test_all_pipeline.py index 24c6765aa..de15d8722 100644 --- a/tests/test_all_pipeline.py +++ b/tests/test_all_pipeline.py @@ -47,13 +47,13 @@ def train(uri_path: str = None): rid = recorder.id sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") return pred_score, {"ic": ic, "ric": ric}, rid @@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None): sr = SignalRecord(model, dataset, recorder) sr.generate() - pred_score = sr.load(sr.get_path("pred.pkl")) + pred_score = sr.load("pred.pkl") # predict and calculate ic and ric sar = SigAnaRecord(recorder) sar.generate() - ic = sar.load(sar.get_path("ic.pkl")) - ric = sar.load(sar.get_path("ric.pkl")) + ic = sar.load("ic.pkl") + ric = sar.load("ric.pkl") uri_path = R.get_uri() return pred_score, {"ic": ic, "ric": ric}, uri_path @@ -169,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None): # backtest par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day") par.generate() - analysis_df = par.load(par.get_path("port_analysis_1day.pkl")) + analysis_df = par.load("port_analysis_1day.pkl") print(analysis_df) return analysis_df