1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00

Merge pull request #650 from microsoft/backtest_improve

Improve the backtest design and APIs
This commit is contained in:
you-n-g
2021-11-08 09:10:33 +08:00
committed by GitHub
65 changed files with 947 additions and 286 deletions

View File

@@ -159,6 +159,21 @@ Version 0.5.0
- Add baselines
- public data crawler
Version greater than Version 0.5.0
Version 0.8.0
--------------------
- The backtest is greatly refactored.
- Nested decision execution framework is supported
- There are lots of changes for daily trading, it is hard to list all of them. But a few important changes could be noticed
- The trading limitation is more accurate;
- In `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/backtest/exchange.py#L160>`_, longing and shorting actions share the same action.
- In `current verison <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/backtest/exchange.py#L304>`_, the trading limitation is different between loging and shorting action.
- The constant is different when calculating annualized metrics.
- `Current version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/contrib/evaluate.py#L42>`_ uses more accurate constant than `previous version <https://github.com/microsoft/qlib/blob/v0.7.2/qlib/contrib/evaluate.py#L22>`_
- `A new version <https://github.com/microsoft/qlib/blob/7c31012b507a3823117bddcc693fc64899460b2a/qlib/tests/data.py#L17>`_ of data is released. Due to the unstability of Yahoo data source, the data may be different after downloading data again.
- Users could chec kout the backtesting results between `Current version <https://github.com/microsoft/qlib/tree/7c31012b507a3823117bddcc693fc64899460b2a/examples/benchmarks>`_ and `previous version <https://github.com/microsoft/qlib/tree/v0.7.2/examples/benchmarks>`_
Other Versions
----------------------------------
Please refer to `Github release Notes <https://github.com/microsoft/qlib/releases>`_

View File

@@ -53,6 +53,9 @@ Below is a typical config file of ``qrun``.
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000
@@ -240,6 +243,9 @@ The following script is the configuration of `backtest` and the `strategy` used
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
limit_threshold: 0.095
account: 100000000

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -86,4 +87,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -100,4 +101,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -35,8 +35,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -94,4 +95,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -14,7 +14,7 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
model: <MODEL>
dataset: <DATASET>
topk: 50
n_drop: 5

View File

@@ -33,6 +33,9 @@ port_analysis_config: &port_analysis_config
kwargs:
topk: 50
n_drop: 5
signal:
- <MODEL>
- <DATASET>
backtest:
verbose: False
limit_threshold: 0.095
@@ -80,4 +83,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -76,4 +77,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -31,8 +31,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -41,8 +41,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -98,4 +99,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -29,8 +29,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -85,4 +86,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -13,6 +13,10 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
>
> In the new version of qlib, the default dataset is **v2**. Since the data is collected from the YahooFinance API (which is not very stable), the results of *v2* and *v1* may differ -->
> NOTE:
> The backtest start from 0.8.0 is quite different from previous version. Please check out the changelog for the difference.
## Alpha158 dataset
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -30,8 +30,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:
@@ -94,4 +95,4 @@ task:
- class: PortAnaRecord
module_path: qlib.workflow.record_temp
kwargs:
config: *port_analysis_config
config: *port_analysis_config

View File

@@ -16,8 +16,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -57,8 +57,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -51,8 +51,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -36,8 +36,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -28,8 +28,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -14,8 +14,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -21,8 +21,9 @@ port_analysis_config: &port_analysis_config
class: TopkDropoutStrategy
module_path: qlib.contrib.strategy
kwargs:
model: <MODEL>
dataset: <DATASET>
signal:
- <MODEL>
- <DATASET>
topk: 50
n_drop: 5
backtest:

View File

@@ -151,10 +151,9 @@ class NestedDecisionExecutionWorkflow:
self._train_model(model, dataset)
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -189,10 +188,9 @@ class NestedDecisionExecutionWorkflow:
backtest_config["benchmark"] = self.benchmark
strategy_config = {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},

View File

@@ -31,10 +31,9 @@ if __name__ == "__main__":
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},

View File

@@ -152,8 +152,11 @@ def init_from_yaml_conf(conf_path, **kwargs):
:param conf_path: A path to the qlib config in yml format
"""
with open(conf_path) as f:
config = yaml.safe_load(f)
if conf_path is None:
config = {}
else:
with open(conf_path) as f:
config = yaml.safe_load(f)
config.update(kwargs)
default_conf = config.pop("default_conf", "client")
init(default_conf, **config)
@@ -216,7 +219,7 @@ def auto_init(**kwargs):
.. code-block:: yaml
conf_type: ref
qlib_cfg: '<shared_yaml_config_path>'
qlib_cfg: '<shared_yaml_config_path>' # this could be null reference no config from other files
# following configs in `qlib_cfg_update` is project=specific
qlib_cfg_update:
exp_manager:
@@ -246,6 +249,7 @@ def auto_init(**kwargs):
except FileNotFoundError:
init(**kwargs)
else:
logger = get_module_logger("Initialization")
conf_pp = pp / "config.yaml"
with conf_pp.open() as f:
conf = yaml.safe_load(f)
@@ -259,8 +263,14 @@ def auto_init(**kwargs):
# - There is a shared configure file and you don't want to edit it inplace.
# - The shared configure may be updated later and you don't want to copy it.
# - You have some customized config.
qlib_conf_path = conf["qlib_cfg"]
qlib_conf_update = conf.get("qlib_cfg_update")
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update, **kwargs)
logger = get_module_logger("Initialization")
qlib_conf_path = conf.get("qlib_cfg", None)
# merge the arguments
qlib_conf_update = conf.get("qlib_cfg_update", {})
for k, v in kwargs.items():
if k in qlib_conf_update:
logger.warning(f"`qlib_conf_update` from conf_pp is override by `kwargs` on key '{k}'")
qlib_conf_update.update(kwargs)
init_from_yaml_conf(qlib_conf_path, **qlib_conf_update)
logger.info(f"Auto load project config: {conf_pp}")

View File

@@ -34,6 +34,7 @@ class Exchange:
open_cost=0.0015,
close_cost=0.0025,
min_cost=5,
impact_cost=0.0,
extra_quote=None,
quote_cls=NumpyQuote,
**kwargs,
@@ -95,6 +96,7 @@ class Exchange:
**NOTE**: `trade_unit` is included in the `kwargs`. It is necessary because we must
distinguish `not set` and `disable trade_unit`
:param min_cost: min cost, default 5
:param impact_cost: market impact cost rate (a.k.a. slippage). A recommended value is 0.1.
:param extra_quote: pandas, dataframe consists of
columns: like ['$vwap', '$close', '$volume', '$factor', 'limit_sell', 'limit_buy'].
The limit indicates that the etf is tradable on a specific day.
@@ -164,9 +166,12 @@ class Exchange:
all_fields = list(all_fields | set(subscribe_fields))
self.all_fields = all_fields
self.open_cost = open_cost
self.close_cost = close_cost
self.min_cost = min_cost
self.impact_cost = impact_cost
self.limit_threshold: Union[Tuple[str, str], float, None] = limit_threshold
self.volume_threshold = volume_threshold
self.extra_quote = extra_quote
@@ -685,12 +690,14 @@ class Exchange:
f"Order clipped due to volume limitation: {order}, {[(vol, rule) for vol, rule in zip(vol_limit_num, vol_limit)]}"
)
def _get_buy_amount_by_cash_limit(self, trade_price, cash):
def _get_buy_amount_by_cash_limit(self, trade_price, cash, cost_ratio):
"""return the real order amount after cash limit for buying.
Parameters
----------
trade_price : float
position : cash
cost_ratio : float
Return
----------
float
@@ -699,10 +706,10 @@ class Exchange:
max_trade_amount = 0
if cash >= self.min_cost:
# critical_price means the stock transaction price when the service fee is equal to min_cost.
critical_price = self.min_cost / self.open_cost + self.min_cost
critical_price = self.min_cost / cost_ratio + self.min_cost
if cash >= critical_price:
# the service fee is equal to open_cost * trade_amount
max_trade_amount = cash / (1 + self.open_cost) / trade_price
# the service fee is equal to cost_ratio * trade_amount
max_trade_amount = cash / (1 + cost_ratio) / trade_price
else:
# the service fee is equal to min_cost
max_trade_amount = (cash - self.min_cost) / trade_price
@@ -718,6 +725,7 @@ class Exchange:
:return: trade_price, trade_val, trade_cost
"""
trade_price = self.get_deal_price(order.stock_id, order.start_time, order.end_time, direction=order.direction)
total_trade_val = self.get_volume(order.stock_id, order.start_time, order.end_time) * trade_price
order.factor = self.get_factor(order.stock_id, order.start_time, order.end_time)
order.deal_amount = order.amount # set to full amount and clip it step by step
# Clipping amount first
@@ -726,8 +734,12 @@ class Exchange:
# - It simulates that the large order is submitted, but partial is dealt regardless of rounding by trading unit.
self._clip_amount_by_volume(order, dealt_order_amount)
# TODO: the adjusted cost ratio can be overestimated as deal_amount will be clipped in the next steps
trade_val = order.deal_amount * trade_price
adj_cost_ratio = self.impact_cost * (trade_val / total_trade_val) ** 2
if order.direction == Order.SELL:
cost_ratio = self.close_cost
cost_ratio = self.close_cost + adj_cost_ratio
# sell
# if we don't know current position, we choose to sell all
# Otherwise, we clip the amount based on current position
@@ -750,14 +762,18 @@ class Exchange:
self.logger.debug(f"Order clipped due to cash limitation: {order}")
elif order.direction == Order.BUY:
cost_ratio = self.open_cost
cost_ratio = self.open_cost + adj_cost_ratio
# buy
if position is not None:
cash = position.get_cash()
trade_val = order.deal_amount * trade_price
if cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
if cash < max(trade_val * cost_ratio, self.min_cost):
# cash cannot cover cost
order.deal_amount = 0
self.logger.debug(f"Order clipped due to cost higher than cash: {order}")
elif cash < trade_val + max(trade_val * cost_ratio, self.min_cost):
# The money is not enough
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash)
max_buy_amount = self._get_buy_amount_by_cash_limit(trade_price, cash, cost_ratio)
order.deal_amount = self.round_amount_by_trade_unit(
min(max_buy_amount, order.deal_amount), order.factor
)

View File

@@ -160,6 +160,11 @@ class NumpyQuote(BaseQuote):
if is_single_value(start_time, end_time, self.freq, self.region):
# this is a very special case.
# skip aggregating function to speed-up the query calculation
# FIXME:
# it will go to the else logic when it comes to the
# 1) the day before holiday when daily trading
# 2) the last minute of the day when intraday trading
try:
return self.data[stock_id].loc[start_time, field]
except KeyError:

102
qlib/backtest/signal.py Normal file
View File

@@ -0,0 +1,102 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from qlib.utils import init_instance_by_config
from typing import Dict, List, Text, Tuple, Union
from ..model.base import BaseModel
from ..data.dataset import Dataset
from ..data.dataset.utils import convert_index_format
from ..utils.resam import resam_ts_data
import pandas as pd
import abc
class Signal(metaclass=abc.ABCMeta):
"""
Some trading strategy make decisions based on other prediction signals
The signals may comes from different sources(e.g. prepared data, online prediction from model and dataset)
This interface is tries to provide unified interface for those different sources
"""
@abc.abstractmethod
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame, None]:
"""
get the signal at the end of the decision step(from `start_time` to `end_time`)
Returns
-------
Union[pd.Series, pd.DataFrame, None]:
returns None if no signal in the specific day
"""
...
class SignalWCache(Signal):
"""
Signal With pandas with based Cache
SignalWCache will store the prepared signal as a attribute and give the according signal based on input query
"""
def __init__(self, signal: Union[pd.Series, pd.DataFrame]):
"""
Parameters
----------
signal : Union[pd.Series, pd.DataFrame]
The expected format of the signal is like the data below (the order of index is not important and can be automatically adjusted)
instrument datetime
SH600000 2008-01-02 0.079704
2008-01-03 0.120125
2008-01-04 0.878860
2008-01-07 0.505539
2008-01-08 0.395004
"""
self.signal_cache = convert_index_format(signal, level="datetime")
def get_signal(self, start_time, end_time) -> Union[pd.Series, pd.DataFrame]:
# the frequency of the signal may not algin with the decision frequency of strategy
# so resampling from the data is necessary
# the latest signal leverage more recent data and therefore is used in trading.
signal = resam_ts_data(self.signal_cache, start_time=start_time, end_time=end_time, method="last")
return signal
class ModelSignal(SignalWCache):
def __init__(self, model: BaseModel, dataset: Dataset):
self.model = model
self.dataset = dataset
pred_scores = self.model.predict(dataset)
if isinstance(pred_scores, pd.DataFrame):
pred_scores = pred_scores.iloc[:, 0]
super().__init__(pred_scores)
def _update_model(self):
"""
When using online data, update model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
# TODO: this method is not included in the framework and could be refactor later
raise NotImplementedError("_update_model is not implemented!")
def create_signal_from(
obj: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame]
) -> Signal:
"""
create signal from diverse information
This method will choose the right method to create a signal based on `obj`
Please refer to the code below.
"""
if isinstance(obj, Signal):
return obj
elif isinstance(obj, (tuple, list)):
return ModelSignal(*obj)
elif isinstance(obj, (dict, str)):
return init_instance_by_config(obj)
elif isinstance(obj, (pd.DataFrame, pd.Series)):
return SignalWCache(signal=obj)
else:
raise NotImplementedError(f"This type of signal is not supported")

View File

@@ -70,7 +70,7 @@ class TradeCalendarManager:
- If self.trade_step >= self.self.trade_len, it means the trading is finished
- If self.trade_step < self.self.trade_len, it means the number of trading step finished is self.trade_step
"""
return self.trade_step >= self.trade_len
return self.trade_step >= self.trade_len - 1
def step(self):
if self.finished():

View File

View File

@@ -0,0 +1,183 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from typing import Dict, Iterable
def align_index(df_dict, join):
res = {}
for k, df in df_dict.items():
if join is not None and k != join:
df = df.reindex(df_dict[join].index)
res[k] = df
return res
# Mocking the pd.DataFrame class
class SepDataFrame:
"""
(Sep)erate DataFrame
We usually concat multiple dataframe to be processed together(Such as feature, label, weight, filter).
However, they are usally be used seperately at last.
This will result in extra cost for concating and spliting data(reshaping and copying data in the memory is very expensive)
SepDataFrame tries to act like a DataFrame whose column with multiindex
"""
def __init__(self, df_dict: Dict[str, pd.DataFrame], join: str, skip_align=False):
"""
initialize the data based on the dataframe dictionary
Parameters
----------
df_dict : Dict[str, pd.DataFrame]
dataframe dictionary
join : str
how to join the data
It will reindex the dataframe based on the join key.
If join is None, the reindex step will be skipped
skip_align :
for some cases, we can improve performance by skipping aligning index
"""
self.join = join
if skip_align:
self._df_dict = df_dict
else:
self._df_dict = align_index(df_dict, join)
@property
def loc(self):
return SDFLoc(self, join=self.join)
@property
def index(self):
return self._df_dict[self.join].index
def apply_each(self, method: str, skip_align=True, *args, **kwargs):
"""
Assumptions:
- inplace methods will return None
"""
inplace = False
df_dict = {}
for k, df in self._df_dict.items():
df_dict[k] = getattr(df, method)(*args, **kwargs)
if df_dict[k] is None:
inplace = True
if not inplace:
return SepDataFrame(df_dict=df_dict, join=self.join, skip_align=skip_align)
def sort_index(self, *args, **kwargs):
return self.apply_each("sort_index", True, *args, **kwargs)
def copy(self, *args, **kwargs):
return self.apply_each("copy", True, *args, **kwargs)
def _update_join(self):
if self.join not in self:
self.join = next(iter(self._df_dict.keys()))
def __getitem__(self, item):
return self._df_dict[item]
def __setitem__(self, item: str, df: pd.DataFrame):
# TODO: consider the join behavior
self._df_dict[item] = df
def __delitem__(self, item: str):
del self._df_dict[item]
self._update_join()
def __contains__(self, item):
return item in self._df_dict
def __len__(self):
return len(self._df_dict[self.join])
def droplevel(self, *args, **kwargs):
raise NotImplementedError(f"Please implement the `droplevel` method")
@property
def columns(self):
dfs = []
for k, df in self._df_dict.items():
df = df.head(0)
df.columns = pd.MultiIndex.from_product([[k], df.columns])
dfs.append(df)
return pd.concat(dfs, axis=1).columns
# Useless methods
@staticmethod
def merge(df_dict: Dict[str, pd.DataFrame], join: str):
all_df = df_dict[join]
for k, df in df_dict.items():
if k != join:
all_df = all_df.join(df)
return all_df
class SDFLoc:
"""Mock Class"""
def __init__(self, sdf: SepDataFrame, join):
self._sdf = sdf
self.axis = None
self.join = join
def __call__(self, axis):
self.axis = axis
return self
def __getitem__(self, args):
if self.axis == 1:
if isinstance(args, str):
return self._sdf[args]
elif isinstance(args, (tuple, list)):
new_df_dict = {k: self._sdf[k] for k in args}
return SepDataFrame(new_df_dict, join=self.join if self.join in args else args[0], skip_align=True)
else:
raise NotImplementedError(f"This type of input is not supported")
elif self.axis == 0:
return SepDataFrame(
{k: df.loc(axis=0)[args] for k, df in self._sdf._df_dict.items()}, join=self.join, skip_align=True
)
else:
df = self._sdf
if isinstance(args, tuple):
ax0, *ax1 = args
if len(ax1) == 0:
ax1 = None
if ax1 is not None:
df = df.loc(axis=1)[ax1]
if ax0 is not None:
df = df.loc(axis=0)[ax0]
return df
else:
return df.loc(axis=0)[args]
# Patch pandas DataFrame
# Tricking isinstance to accept SepDataFrame as its subclass
import builtins
def _isinstance(instance, cls):
if isinstance_orig(instance, SepDataFrame): # pylint: disable=E0602
if isinstance(cls, Iterable):
for c in cls:
if c is pd.DataFrame:
return True
elif cls is pd.DataFrame:
return True
return isinstance_orig(instance, cls) # pylint: disable=E0602
builtins.isinstance_orig = builtins.isinstance
builtins.isinstance = _isinstance
if __name__ == "__main__":
sdf = SepDataFrame({}, join=None)
print(isinstance(sdf, (pd.DataFrame,)))
print(isinstance(sdf, pd.DataFrame))

View File

@@ -2,7 +2,7 @@
# Licensed under the MIT License.
from .model_strategy import (
from .signal_strategy import (
TopkDropoutStrategy,
WeightStrategyBase,
)

View File

@@ -6,7 +6,7 @@ This strategy is not well maintained
from .order_generator import OrderGenWInteract
from .model_strategy import WeightStrategyBase
from .signal_strategy import WeightStrategyBase
import copy

View File

@@ -80,18 +80,22 @@ class OrderGenWInteract(OrderGenerator):
:rtype: list
"""
if target_weight_position is None:
return []
# calculate current_tradable_value
current_amount_dict = current.get_stock_amount_dict()
current_total_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
only_tradable=False,
)
current_tradable_value = trade_exchange.calculate_amount_position_value(
amount_dict=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
only_tradable=True,
)
# add cash
@@ -105,9 +109,7 @@ class OrderGenWInteract(OrderGenerator):
# value. Then just sell all the stocks
target_amount_dict = copy.deepcopy(current_amount_dict.copy())
for stock_id in list(target_amount_dict.keys()):
if trade_exchange.is_stock_tradable(
stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
):
if trade_exchange.is_stock_tradable(stock_id, start_time=trade_start_time, end_time=trade_end_time):
del target_amount_dict[stock_id]
else:
# consider cost rate
@@ -118,16 +120,16 @@ class OrderGenWInteract(OrderGenerator):
target_amount_dict = trade_exchange.generate_amount_position_from_weight_position(
weight_position=target_weight_position,
cash=current_tradable_value,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=target_amount_dict,
current_position=current_amount_dict,
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
return TradeDecisionWO(order_list, self)
return order_list
class OrderGenWOInteract(OrderGenerator):
@@ -163,8 +165,11 @@ class OrderGenWOInteract(OrderGenerator):
:param trade_date:
:type trade_date: pd.Timestamp
:rtype: list
:rtype: list of generated orders
"""
if target_weight_position is None:
return []
risk_total_value = risk_degree * current.calculate_value()
current_stock = current.get_stock_list()
@@ -172,13 +177,17 @@ class OrderGenWOInteract(OrderGenerator):
for stock_id in target_weight_position:
# Current rule will ignore the stock that not hold and cannot be traded at predict date
if trade_exchange.is_stock_tradable(
stock_id=stock_id, trade_start_time=trade_start_time, trade_end_time=trade_end_time
stock_id=stock_id, start_time=trade_start_time, end_time=trade_end_time
) and trade_exchange.is_stock_tradable(
stock_id=stock_id, start_time=pred_start_time, end_time=pred_end_time
):
amount_dict[stock_id] = (
risk_total_value
* target_weight_position[stock_id]
/ trade_exchange.get_close(stock_id, trade_start_time=pred_start_time, trade_end_time=pred_end_time)
/ trade_exchange.get_close(stock_id, start_time=pred_start_time, end_time=pred_end_time)
)
# TODO: Qlib use None to represent trading suspension. So last close price can't be the estimated trading price.
# Maybe a close price with forward fill will be a better solution.
elif stock_id in current_stock:
amount_dict[stock_id] = (
risk_total_value * target_weight_position[stock_id] / current.get_stock_price(stock_id)
@@ -188,7 +197,7 @@ class OrderGenWOInteract(OrderGenerator):
order_list = trade_exchange.generate_order_for_target_amount_position(
target_position=amount_dict,
current_position=current.get_stock_amount_dict(),
trade_start_time=trade_start_time,
trade_end_time=trade_end_time,
start_time=trade_start_time,
end_time=trade_end_time,
)
return TradeDecisionWO(order_list, self)
return order_list

View File

@@ -1,3 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from pathlib import Path
import warnings
import numpy as np

View File

@@ -1,27 +1,33 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import copy
from qlib.backtest.signal import Signal, create_signal_from
from typing import Dict, List, Text, Tuple, Union
from qlib.data.dataset import Dataset
from qlib.model.base import BaseModel
from qlib.backtest.position import Position
import warnings
import numpy as np
import pandas as pd
from ...utils.resam import resam_ts_data
from ...strategy.base import ModelStrategy
from ...strategy.base import BaseStrategy
from ...backtest.decision import Order, BaseTradeDecision, OrderDir, TradeDecisionWO
from .order_generator import OrderGenWInteract
class TopkDropoutStrategy(ModelStrategy):
class TopkDropoutStrategy(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
topk,
n_drop,
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame] = None,
method_sell="bottom",
method_buy="top",
risk_degree=0.95,
@@ -30,6 +36,8 @@ class TopkDropoutStrategy(ModelStrategy):
trade_exchange=None,
level_infra=None,
common_infra=None,
model=None,
dataset=None,
**kwargs,
):
"""
@@ -39,6 +47,9 @@ class TopkDropoutStrategy(ModelStrategy):
the number of stocks in the portfolio.
n_drop : int
number of stocks to be replaced in each trading date.
signal :
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
the decision of the strategy will base on the given signal
method_sell : str
dropout method_sell, random/bottom.
method_buy : str
@@ -64,7 +75,7 @@ class TopkDropoutStrategy(ModelStrategy):
"""
super(TopkDropoutStrategy, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
self.topk = topk
self.n_drop = n_drop
@@ -74,6 +85,13 @@ class TopkDropoutStrategy(ModelStrategy):
self.hold_thresh = hold_thresh
self.only_tradable = only_tradable
# This is trying to be compatible with previous version of qlib task config
if model is not None and dataset is not None:
warnings.warn("`model` `dataset` is deprecated; use `signal`.", DeprecationWarning)
signal = model, dataset
self.signal: Signal = create_signal_from(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -87,7 +105,7 @@ class TopkDropoutStrategy(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
if self.only_tradable:
@@ -235,15 +253,15 @@ class TopkDropoutStrategy(ModelStrategy):
return TradeDecisionWO(sell_order_list + buy_order_list, self)
class WeightStrategyBase(ModelStrategy):
class WeightStrategyBase(BaseStrategy):
# TODO:
# 1. Supporting leverage the get_range_limit result from the decision
# 2. Supporting alter_outer_trade_decision
# 3. Supporting checking the availability of trade decision
def __init__(
self,
model,
dataset,
*,
signal: Union[Signal, Tuple[BaseModel, Dataset], List, Dict, Text, pd.Series, pd.DataFrame],
order_generator_cls_or_obj=OrderGenWInteract,
trade_exchange=None,
level_infra=None,
@@ -251,6 +269,9 @@ class WeightStrategyBase(ModelStrategy):
**kwargs,
):
"""
signal :
the information to describe a signal. Please refer to the docs of `qlib.backtest.signal.create_signal_from`
the decision of the strategy will base on the given signal
trade_exchange : Exchange
exchange that provides market info, used to deal order and generate report
- If `trade_exchange` is None, self.trade_exchange will be set with common_infra
@@ -260,13 +281,15 @@ class WeightStrategyBase(ModelStrategy):
- In minutely execution, the daily exchange is not usable, only the minutely exchange is recommended.
"""
super(WeightStrategyBase, self).__init__(
model, dataset, level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
level_infra=level_infra, common_infra=common_infra, trade_exchange=trade_exchange, **kwargs
)
if isinstance(order_generator_cls_or_obj, type):
self.order_generator = order_generator_cls_or_obj()
else:
self.order_generator = order_generator_cls_or_obj
self.signal: Signal = create_signal_from(signal)
def get_risk_degree(self, trade_step=None):
"""get_risk_degree
Return the proportion of your total value you will used in investment.
@@ -298,7 +321,7 @@ class WeightStrategyBase(ModelStrategy):
trade_step = self.trade_calendar.get_trade_step()
trade_start_time, trade_end_time = self.trade_calendar.get_step_time(trade_step)
pred_start_time, pred_end_time = self.trade_calendar.get_step_time(trade_step, shift=1)
pred_score = resam_ts_data(self.pred_scores, start_time=pred_start_time, end_time=pred_end_time, method="last")
pred_score = self.signal.get_signal(start_time=pred_start_time, end_time=pred_end_time)
if pred_score is None:
return TradeDecisionWO([], self)
current_temp = copy.deepcopy(self.trade_position)

View File

@@ -49,7 +49,7 @@ class MultiSegRecord(RecordTemp):
if save:
save_name = "results-{:}.pkl".format(key)
self.recorder.save_objects(**{save_name: results})
self.save(**{save_name: results})
logger.info(
"The record '{:}' has been saved as the artifact of the Experiment {:}".format(
save_name, self.recorder.experiment_id
@@ -79,9 +79,8 @@ class SignalMseRecord(RecordTemp):
metrics = {"MSE": mse, "RMSE": np.sqrt(mse)}
objects = {"mse.pkl": mse, "rmse.pkl": np.sqrt(mse)}
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
logger.info("The evaluation results in SignalMseRecord is {:}".format(metrics))
def list(self):
paths = [self.get_path("mse.pkl"), self.get_path("rmse.pkl")]
return paths
return ["mse.pkl", "rmse.pkl"]

View File

@@ -320,6 +320,7 @@ class TSDataSampler:
self.flt_data = flt_data.values
self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map)
self.data_index = self.data_index[np.where(self.flt_data == True)[0]]
self.idx_map = self.idx_map2arr(self.idx_map)
self.start_idx, self.end_idx = self.data_index.slice_locs(
start=time_to_slc_point(start), end=time_to_slc_point(end)
@@ -328,6 +329,25 @@ class TSDataSampler:
del self.data # save memory
@staticmethod
def idx_map2arr(idx_map):
# pytorch data sampler will have better memory control without large dict or list
# - https://github.com/pytorch/pytorch/issues/13243
# - https://github.com/airctic/icevision/issues/613
# So we convert the dict into int array.
# The arr_map is expected to behave the same as idx_map
dtype = np.int32
# set a index out of bound to indicate the none existing
no_existing_idx = (np.iinfo(dtype).max, np.iinfo(dtype).max)
max_idx = max(idx_map.keys())
arr_map = []
for i in range(max_idx + 1):
arr_map.append(idx_map.get(i, no_existing_idx))
arr_map = np.array(arr_map, dtype=dtype)
return arr_map
@staticmethod
def flt_idx_map(flt_data, idx_map):
idx = 0
@@ -524,20 +544,18 @@ class TSDatasetH(DatasetH):
def setup_data(self, **kwargs):
super().setup_data(**kwargs)
# make sure the calendar is updated to latest when loading data from new config
cal = self.handler.fetch(col_set=self.handler.CS_RAW).index.get_level_values("datetime").unique()
cal = sorted(cal)
self.cal = cal
self.cal = sorted(cal)
def _prepare_raw_seg(self, slc: slice, **kwargs) -> pd.DataFrame:
@staticmethod
def _extend_slice(slc: slice, cal: list, step_len: int) -> slice:
# Dataset decide how to slice data(Get more data for timeseries).
start, end = slc.start, slc.stop
start_idx = bisect.bisect_left(self.cal, pd.Timestamp(start))
pad_start_idx = max(0, start_idx - self.step_len)
pad_start = self.cal[pad_start_idx]
# TSDatasetH will retrieve more data for complete
data = super()._prepare_seg(slice(pad_start, end), **kwargs)
return data
start_idx = bisect.bisect_left(cal, pd.Timestamp(start))
pad_start_idx = max(0, start_idx - step_len)
pad_start = cal[pad_start_idx]
return slice(pad_start, end)
def _prepare_seg(self, slc: slice, **kwargs) -> TSDataSampler:
"""
@@ -546,13 +564,15 @@ class TSDatasetH(DatasetH):
dtype = kwargs.pop("dtype", None)
start, end = slc.start, slc.stop
flt_col = kwargs.pop("flt_col", None)
# TSDatasetH will retrieve more data for complete
data = self._prepare_raw_seg(slc, **kwargs)
# TSDatasetH will retrieve more data for complete time-series
ext_slice = self._extend_slice(slc, self.cal, self.step_len)
data = super()._prepare_seg(ext_slice, **kwargs)
flt_kwargs = deepcopy(kwargs)
if flt_col is not None:
flt_kwargs["col_set"] = flt_col
flt_data = self._prepare_raw_seg(slc, **flt_kwargs)
flt_data = self._prepare_seg(ext_slice, **flt_kwargs)
assert len(flt_data.columns) == 1
else:
flt_data = None

View File

@@ -82,8 +82,6 @@ class DataHandler(Serializable):
fetch_orig : bool
Return the original data instead of copy if possible.
"""
# Set logger
self.logger = get_module_logger("DataHandler")
# Setup data loader
assert data_loader is not None # to make start_time end_time could have None default value
@@ -302,6 +300,7 @@ class DataHandlerLP(DataHandler):
DK_R = "raw"
DK_I = "infer"
DK_L = "learn"
ATTR_MAP = {DK_R: "_data", DK_I: "_infer", DK_L: "_learn"}
# process type
PTYPE_I = "independent"
@@ -543,7 +542,7 @@ class DataHandlerLP(DataHandler):
raise AttributeError(
"DataHandlerLP has not attribute _data, please set drop_raw = False if you want to use raw data"
)
df = getattr(self, {self.DK_R: "_data", self.DK_I: "_infer", self.DK_L: "_learn"}[data_key])
df = getattr(self, self.ATTR_MAP[data_key])
return df
def fetch(
@@ -624,3 +623,33 @@ class DataHandlerLP(DataHandler):
df = self._get_df_by_key(data_key).head()
df = fetch_df_by_col(df, col_set)
return df.columns.to_list()
@classmethod
def cast(cls, handler: "DataHandlerLP") -> "DataHandlerLP":
"""
Motivation
- A user create a datahandler in his customized package. Then he want to share the processed handler to other users without introduce the package dependency and complicated data processing logic.
- This class make it possible by casting the class to DataHandlerLP and only keep the processed data
Parameters
----------
handler : DataHandlerLP
A subclass of DataHandlerLP
Returns
-------
DataHandlerLP:
the converted processed data
"""
new_hd: DataHandlerLP = object.__new__(DataHandlerLP)
new_hd.from_cast = True # add a mark for the casted instance
for key in list(DataHandlerLP.ATTR_MAP.values()) + [
"instruments",
"start_time",
"end_time",
"fetch_orig",
"drop_raw",
]:
setattr(new_hd, key, getattr(handler, key, None))
return new_hd

View File

@@ -17,7 +17,7 @@ from ..utils import init_instance_by_config
from ..backtest.utils import CommonInfrastructure, LevelInfrastructure, TradeCalendarManager
from ..backtest.decision import BaseTradeDecision
__all__ = ["BaseStrategy", "ModelStrategy", "RLStrategy", "RLIntStrategy"]
__all__ = ["BaseStrategy", "RLStrategy", "RLIntStrategy"]
class BaseStrategy:
@@ -194,45 +194,6 @@ class BaseStrategy:
return max(cal_range[0], range_limit[0]), min(cal_range[1], range_limit[1])
class ModelStrategy(BaseStrategy):
"""Model-based trading strategy, use model to make predictions for trading"""
def __init__(
self,
model: BaseModel,
dataset: DatasetH,
outer_trade_decision: BaseTradeDecision = None,
level_infra: LevelInfrastructure = None,
common_infra: CommonInfrastructure = None,
**kwargs,
):
"""
Parameters
----------
model : BaseModel
the model used in when making predictions
dataset : DatasetH
provide test data for model
kwargs : dict
arguments that will be passed into `reset` method
"""
super(ModelStrategy, self).__init__(outer_trade_decision, level_infra, common_infra, **kwargs)
self.model = model
self.dataset = dataset
self.pred_scores = convert_index_format(self.model.predict(dataset), level="datetime")
if isinstance(self.pred_scores, pd.DataFrame):
self.pred_scores = self.pred_scores.iloc[:, 0]
def _update_model(self):
"""
When using online data, pdate model in each bar as the following steps:
- update dataset with online data, the dataset should support online update
- make the latest prediction scores of the new bar
- update the pred score into the latest prediction
"""
raise NotImplementedError("_update_model is not implemented!")
class RLStrategy(BaseStrategy):
"""RL-based strategy"""

View File

@@ -199,6 +199,7 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
----------
config : [dict, str]
similar to config
please refer to the doc of init_instance_by_config
default_module : Python module or str
It should be a python module to load the class type
@@ -219,9 +220,12 @@ def get_callable_kwargs(config: Union[dict, str], default_module: Union[str, Mod
_callable = config["class"] # the class type itself is passed in
kwargs = config.get("kwargs", {})
elif isinstance(config, str):
module = get_module_by_module_path(default_module)
# a.b.c.ClassName
*m_path, cls = config.split(".")
m_path = ".".join(m_path)
module = get_module_by_module_path(default_module if m_path == "" else m_path)
_callable = getattr(module, config)
_callable = getattr(module, cls)
kwargs = {}
else:
raise NotImplementedError(f"This type of input is not supported")
@@ -260,7 +264,9 @@ def init_instance_by_config(
1) specify a pickle object
- path like 'file:///<path to pickle file>/obj.pkl'
2) specify a class name
- "ClassName": getattr(module, config)() will be used.
- "ClassName": getattr(module, "ClassName")() will be used.
3) specify module path with class name
- "a.b.c.ClassName" getattr(<a.b.c.module>, "ClassName")() will be used.
object example:
instance of accept_types
default_module : Python module

View File

@@ -401,6 +401,10 @@ class IndexData(metaclass=index_data_ops_creator):
def columns(self):
return self.indices[1]
def __getitem__(self, args):
# NOTE: this tries to behave like a numpy array to be compatible with numpy aggregating function like nansum and nanmean
return self.iloc[args]
def _align_indices(self, other: "IndexData") -> "IndexData":
"""
Align all indices of `other` to `self` before performing the arithmetic operations.
@@ -409,7 +413,7 @@ class IndexData(metaclass=index_data_ops_creator):
Parameters
----------
other : "IndexData"
the index in `other` is to be chagned
the index in `other` is to be changed
Returns
-------
@@ -455,7 +459,8 @@ class IndexData(metaclass=index_data_ops_creator):
"""
return len(self.data)
def sum(self, axis=None):
def sum(self, axis=None, dtype=None, out=None):
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
# FIXME: weird logic and not general
if axis is None:
return np.nansum(self.data)
@@ -468,7 +473,8 @@ class IndexData(metaclass=index_data_ops_creator):
else:
raise ValueError(f"axis must be None, 0 or 1")
def mean(self, axis=None):
def mean(self, axis=None, dtype=None, out=None):
assert out is None and dtype is None, "`out` is just for compatible with numpy's aggregating function"
# FIXME: weird logic and not general
if axis is None:
return np.nanmean(self.data)

View File

@@ -1,9 +1,15 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import pandas as pd
from functools import partial
from threading import Thread
from typing import Callable
from joblib import Parallel, delayed
from joblib._parallel_backends import MultiprocessingBackend
import pandas as pd
from queue import Queue
class ParallelExt(Parallel):
@@ -46,3 +52,54 @@ def datetime_groupby_apply(df, apply_func, axis=0, level="datetime", resample_ru
return pd.concat(dfs, axis=axis).sort_index()
else:
return _naive_group_apply(df)
class AsyncCaller:
"""
This AsyncCaller tries to make it easier to async call
Currently, it is used in MLflowRecorder to make functions like `log_params` async
NOTE:
- This caller didn't consider the return value
"""
STOP_MARK = "__STOP"
def __init__(self) -> None:
self._q = Queue()
self._stop = False
self._t = Thread(target=self.run)
self._t.start()
def close(self):
self._q.put(self.STOP_MARK)
def run(self):
while True:
data = self._q.get()
if data == self.STOP_MARK:
break
else:
data()
def __call__(self, func, *args, **kwargs):
self._q.put(partial(func, *args, **kwargs))
def wait(self, close=True):
if close:
self.close()
self._t.join()
@staticmethod
def async_dec(ac_attr):
def decorator_func(func):
def wrapper(self, *args, **kwargs):
if isinstance(getattr(self, ac_attr, None), Callable):
return getattr(self, ac_attr)(func, self, *args, **kwargs)
else:
return func(self, *args, **kwargs)
return wrapper
return decorator_func

View File

@@ -21,19 +21,65 @@ Situations Description
Online + Trainer When you want to do a REAL routine, the Trainer will help you train the models. It
will train models task by task and strategy by strategy.
Online + DelayTrainer When your models don't have any temporal dependence, the DelayTrainer will train
nothing until all tasks have been prepared. It makes user can train all tasks in
the end of `routine` or `first_train`.
Online + DelayTrainer DelayTrainer will skip concrete training until all tasks have been prepared by
different strategies. It makes users can parallelly train all tasks at the end of
`routine` or `first_train`. Otherwise, these functions will get stuck when each
strategy prepare tasks.
Simulation + Trainer When your models have some temporal dependence on the previous models, then you
need to consider using Trainer. This means it will REAL train your models in
every routine and prepare signals for every routine.
Simulation + Trainer It will behave in the same way as `Online + Trainer`. The only difference is that it
is for simulation/backtesting instead of online trading
Simulation + DelayTrainer When your models don't have any temporal dependence, you can use DelayTrainer
for the ability to multitasking. It means all tasks in all routines
can be REAL trained at the end of simulating. The signals will be prepared well at
different time segments (based on whether or not any new model is online).
========================= ===================================================================================
Here is some pseudo code the demonstrate the workflow of each situation
For simplicity
- Only one strategy is used in the strategy
- `update_online_pred` is only called in the online mode and is ignored
1) `Online + Trainer`
.. code-block:: python
tasks = first_train()
models = trainer.train(tasks)
trainer.end_train(models)
for day in online_trading_days:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # for each strategy
strategy.prepare_online_models(models) # for each strategy
trainer.end_train(models)
prepare_signals() # prepare trading signals daily
`Online + DelayTrainer`: the workflow is the same as `Online + Trainer`.
2) `Simulation + DelayTrainer`
.. code-block:: python
# simulate
tasks = first_train()
models = trainer.train(tasks)
for day in historical_calendars:
# OnlineManager.routine
models = trainer.train(strategy.prepare_tasks()) # for each strategy
strategy.prepare_online_models(models) # for each strategy
# delay_prepare()
# FIXME: Currently the delay_prepare is not implemented in a proper way.
trainer.end_train(<for all previous models>)
prepare_signals()
# Can we simplify current workflow?
- Can reduce the number of state of tasks?
- For each task, we have three phases (i.e. task, partly trained task, final trained task)
"""
import logging
@@ -58,7 +104,7 @@ class OnlineManager(Serializable):
"""
STATUS_SIMULATING = "simulating" # when calling `simulate`
STATUS_NORMAL = "normal" # the normal status
STATUS_ONLINE = "online" # the normal status. It is used when online trading
def __init__(
self,
@@ -87,12 +133,24 @@ class OnlineManager(Serializable):
self.begin_time = pd.Timestamp(begin_time)
self.cur_time = self.begin_time
# OnlineManager will recorder the history of online models, which is a dict like {pd.Timestamp, {strategy, [online_models]}}.
# It records the online servnig models of each strategy for each day.
self.history = {}
if trainer is None:
trainer = TrainerR()
self.trainer = trainer
self.signals = None
self.status = self.STATUS_NORMAL
self.status = self.STATUS_ONLINE
def _postpone_action(self):
"""
Should the workflow to postpone the following actions to the end (in delay_prepare)
- trainer.end_train
- prepare_signals
Postpone these actions is to support simulating/backtest online strategies without time dependencies.
All the actions can be done parallelly at the end.
"""
return self.status == self.STATUS_SIMULATING and self.trainer.is_delay()
def first_train(self, strategies: List[OnlineStrategy] = None, model_kwargs: dict = {}):
"""
@@ -113,12 +171,12 @@ class OnlineManager(Serializable):
models = self.trainer.train(tasks, experiment_name=strategy.name_id)
models_list.append(models)
self.logger.info(f"Finished training {len(models)} models.")
# FIXME: Traing multiple online models at `first_train` will result in getting too much online models at the
# FIXME: Train multiple online models at `first_train` will result in getting too much online models at the
# start.
online_models = strategy.prepare_online_models(models, **model_kwargs)
self.history.setdefault(self.cur_time, {})[strategy] = online_models
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
if not self._postpone_action():
for strategy, models in zip(strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
@@ -160,10 +218,10 @@ class OnlineManager(Serializable):
# The online model may changes in the above processes
# So updating the predictions of online models should be the last step
if self.status == self.STATUS_NORMAL:
if self.status == self.STATUS_ONLINE:
strategy.tool.update_online_pred()
if not self.status == self.STATUS_SIMULATING or not self.trainer.is_delay():
if not self._postpone_action():
for strategy, models in zip(self.strategies, models_list):
models = self.trainer.end_train(models, experiment_name=strategy.name_id)
self.prepare_signals(**signal_kwargs)
@@ -278,13 +336,13 @@ class OnlineManager(Serializable):
signal_kwargs=signal_kwargs,
)
# delay prepare the models and signals
if self.trainer.is_delay():
if self._postpone_action():
self.delay_prepare(model_kwargs=model_kwargs, signal_kwargs=signal_kwargs)
# FIXME: get logging level firstly and restore it here
set_global_logger_level(logging.DEBUG)
self.logger.info(f"Finished preparing signals")
self.status = self.STATUS_NORMAL
self.status = self.STATUS_ONLINE
return self.get_signals()
def delay_prepare(self, model_kwargs={}, signal_kwargs={}):
@@ -295,6 +353,8 @@ class OnlineManager(Serializable):
model_kwargs: the params for `end_train`
signal_kwargs: the params for `prepare_signals`
"""
# FIXME:
# This method is not implemented in the proper way!!!
last_models = {}
signals_time = D.calendar()[0]
need_prepare = False

View File

@@ -9,6 +9,9 @@ import pandas as pd
from pathlib import Path
from pprint import pprint
from typing import Union, List
from collections import defaultdict
from qlib.utils.exceptions import LoadObjectError
from ..contrib.evaluate import indicator_analysis, risk_analysis, indicator_analysis
from ..data.dataset import DatasetH
@@ -45,6 +48,16 @@ class RecordTemp:
return "/".join(names)
def save(self, **kwargs):
"""
It behaves the same as self.recorder.save_objects.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
"""
art_path = self.get_path()
if art_path == "":
art_path = None
self.recorder.save_objects(artifact_path=art_path, **kwargs)
def __init__(self, recorder):
self._recorder = recorder
@@ -67,31 +80,37 @@ class RecordTemp:
"""
raise NotImplementedError(f"Please implement the `generate` method.")
def load(self, name):
def load(self, name: str, parents: bool = True):
"""
Load the stored records. Due to the fact that some problems occured when we tried to balancing a clean API
with the Python's inheritance. This method has to be used in a rather ugly way, and we will try to fix them
in the future::
sar = SigAnaRecord(recorder)
ic = sar.load(sar.get_path("ic.pkl"))
It behaves the same as self.recorder.load_object.
But it is an easier interface because users don't have to care about `get_path` and `artifact_path`
Parameters
----------
name : str
the name for the file to be load.
parents : bool
Each recorder has different `artifact_path`.
So parents recursively find the path in parents
Sub classes has higher priority
Return
------
The stored records.
"""
# try to load the saved object
obj = self.recorder.load_object(name)
return obj
try:
return self.recorder.load_object(self.get_path(name))
except LoadObjectError:
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
return self.load(name, parents=True)
def list(self):
"""
List the supported artifacts.
Users don't have to consider self.get_path
Return
------
@@ -99,7 +118,7 @@ class RecordTemp:
"""
return []
def check(self, include_self: bool = False):
def check(self, include_self: bool = False, parents: bool = True):
"""
Check if the records is properly generated and saved.
It is useful in following examples
@@ -110,19 +129,34 @@ class RecordTemp:
----------
include_self : bool
is the file generated by self included
parents : bool
will we check parents
Raise
------
FileExistsError: whether the records are stored properly.
FileNotFoundError
: whether the records are stored properly.
"""
artifacts = set(self.recorder.list_artifacts())
if include_self:
# Some mlflow backend will not list the directly recursively.
# So we force to the directly
artifacts = {}
def _get_arts(dirn):
if dirn not in artifacts:
artifacts[dirn] = self.recorder.list_artifacts(dirn)
return artifacts[dirn]
for item in self.list():
if item not in artifacts:
raise FileExistsError(item)
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
ps = self.get_path(item).split("/")
dirn, fn = "/".join(ps[:-1]), ps[-1]
if self.get_path(item) not in _get_arts(dirn):
raise FileNotFoundError
if parents:
if self.depend_cls is not None:
with class_casting(self, self.depend_cls):
self.check(include_self=True)
class SignalRecord(RecordTemp):
@@ -158,7 +192,7 @@ class SignalRecord(RecordTemp):
pred = self.model.predict(self.dataset)
if isinstance(pred, pd.Series):
pred = pred.to_frame("score")
self.recorder.save_objects(**{"pred.pkl": pred})
self.save(**{"pred.pkl": pred})
logger.info(
f"Signal record 'pred.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
@@ -169,15 +203,11 @@ class SignalRecord(RecordTemp):
if isinstance(self.dataset, DatasetH):
raw_label = self.generate_label(self.dataset)
self.recorder.save_objects(**{"label.pkl": raw_label})
self.save(**{"label.pkl": raw_label})
@staticmethod
def list():
def list(self):
return ["pred.pkl", "label.pkl"]
def load(self, name="pred.pkl"):
return super().load(name)
class HFSignalRecord(SignalRecord):
"""
@@ -218,19 +248,11 @@ class HFSignalRecord(SignalRecord):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [
self.get_path("ic.pkl"),
self.get_path("ric.pkl"),
self.get_path("long_pre.pkl"),
self.get_path("short_pre.pkl"),
self.get_path("long_short_r.pkl"),
self.get_path("long_avg_r.pkl"),
]
return paths
return ["ic.pkl", "ric.pkl", "long_pre.pkl", "short_pre.pkl", "long_short_r.pkl", "long_avg_r.pkl"]
class SigAnaRecord(RecordTemp):
@@ -241,13 +263,23 @@ class SigAnaRecord(RecordTemp):
artifact_path = "sig_analysis"
depend_cls = SignalRecord
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0):
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, label_col=0, skip_existing=False):
super().__init__(recorder=recorder)
self.ana_long_short = ana_long_short
self.ann_scaler = ann_scaler
self.label_col = label_col
self.skip_existing = skip_existing
def generate(self, **kwargs):
if self.skip_existing:
try:
self.check(include_self=True, parents=False)
except FileNotFoundError:
pass # continue to generating metrics
else:
logger.info("The results has previously generated, generation skipped.")
return
self.check()
pred = self.load("pred.pkl")
@@ -280,13 +312,13 @@ class SigAnaRecord(RecordTemp):
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
self.save(**objects)
pprint(metrics)
def list(self):
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
paths = ["ic.pkl", "ric.pkl"]
if self.ana_long_short:
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
paths.extend(["long_short_r.pkl", "long_avg_r.pkl"])
return paths
@@ -373,17 +405,11 @@ class PortAnaRecord(RecordTemp):
executor=self.executor_config, strategy=self.strategy_config, **self.backtest_config
)
for _freq, (report_normal, positions_normal) in portfolio_metric_dict.items():
self.recorder.save_objects(
**{f"report_normal_{_freq}.pkl": report_normal}, artifact_path=PortAnaRecord.get_path()
)
self.recorder.save_objects(
**{f"positions_normal_{_freq}.pkl": positions_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"report_normal_{_freq}.pkl": report_normal})
self.save(**{f"positions_normal_{_freq}.pkl": positions_normal})
for _freq, indicators_normal in indicator_dict.items():
self.recorder.save_objects(
**{f"indicators_normal_{_freq}.pkl": indicators_normal}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicators_normal_{_freq}.pkl": indicators_normal})
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq not in portfolio_metric_dict:
@@ -405,9 +431,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = flatten_dict(analysis_df["risk"].unstack().T.to_dict())
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"port_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"port_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Portfolio analysis record 'port_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -432,9 +456,7 @@ class PortAnaRecord(RecordTemp):
analysis_dict = analysis_df["value"].to_dict()
self.recorder.log_metrics(**{f"{_analysis_freq}.{k}": v for k, v in analysis_dict.items()})
# save results
self.recorder.save_objects(
**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df}, artifact_path=PortAnaRecord.get_path()
)
self.save(**{f"indicator_analysis_{_analysis_freq}.pkl": analysis_df})
logger.info(
f"Indicator analysis record 'indicator_analysis_{_analysis_freq}.pkl' has been saved as the artifact of the Experiment {self.recorder.experiment_id}"
)
@@ -446,20 +468,19 @@ class PortAnaRecord(RecordTemp):
for _freq in self.all_freq:
list_path.extend(
[
PortAnaRecord.get_path(f"report_normal_{_freq}.pkl"),
PortAnaRecord.get_path(f"positions_normal_{_freq}.pkl"),
f"report_normal_{_freq}.pkl",
f"positions_normal_{_freq}.pkl",
]
)
for _analysis_freq in self.risk_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"port_analysis_{_analysis_freq}.pkl"))
list_path.append(f"port_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"risk_analysis freq {_analysis_freq} is not found")
for _analysis_freq in self.indicator_analysis_freq:
if _analysis_freq in self.all_freq:
list_path.append(PortAnaRecord.get_path(f"indicator_analysis_{_analysis_freq}.pkl"))
list_path.append(f"indicator_analysis_{_analysis_freq}.pkl")
else:
warnings.warn(f"indicator_analysis freq {_analysis_freq} is not found")
return list_path

View File

@@ -1,6 +1,7 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import os
from qlib.utils.serial import Serializable
import mlflow, logging
import shutil, os, pickle, tempfile, codecs, pickle
@@ -8,8 +9,10 @@ from pathlib import Path
from datetime import datetime
from qlib.utils.exceptions import LoadObjectError
from qlib.utils.paral import AsyncCaller
from ..utils.objm import FileManager
from ..log import get_module_logger
from ..log import TimeInspector, get_module_logger
from mlflow.store.artifact.azure_blob_artifact_repo import AzureBlobArtifactRepository
logger = get_module_logger("workflow", logging.INFO)
@@ -227,6 +230,7 @@ class MLflowRecorder(Recorder):
if mlflow_run.info.end_time is not None
else None
)
self.async_log = None
def __repr__(self):
name = self.__class__.__name__
@@ -285,6 +289,10 @@ class MLflowRecorder(Recorder):
self.status = Recorder.STATUS_R
logger.info(f"Recorder {self.id} starts running under Experiment {self.experiment_id} ...")
# NOTE: making logging async.
# - This may cause delay when uploading results
# - The logging time may not be accurate
self.async_log = AsyncCaller()
return run
def end_run(self, status: str = Recorder.STATUS_S):
@@ -298,6 +306,9 @@ class MLflowRecorder(Recorder):
self.end_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.status != Recorder.STATUS_S:
self.status = status
with TimeInspector.logt("waiting `async_log`"):
self.async_log.wait()
self.async_log = None
def save_objects(self, local_path=None, artifact_path=None, **kwargs):
assert self.uri is not None, "Please start the experiment and recorder first before using recorder directly."
@@ -333,18 +344,27 @@ class MLflowRecorder(Recorder):
try:
path = self.client.download_artifacts(self.id, name)
with Path(path).open("rb") as f:
return pickle.load(f)
data = pickle.load(f)
ar = self.client._tracking_client._get_artifact_repo(self.id)
if isinstance(ar, AzureBlobArtifactRepository):
# for saving disk space
# For safety, only remove redundant file for specific ArtifactRepository
shutil.rmtree(Path(path).absolute().parent)
return data
except Exception as e:
raise LoadObjectError(message=str(e))
@AsyncCaller.async_dec(ac_attr="async_log")
def log_params(self, **kwargs):
for name, data in kwargs.items():
self.client.log_param(self.id, name, data)
@AsyncCaller.async_dec(ac_attr="async_log")
def log_metrics(self, step=None, **kwargs):
for name, data in kwargs.items():
self.client.log_metric(self.id, name, data, step=step)
@AsyncCaller.async_dec(ac_attr="async_log")
def set_tags(self, **kwargs):
for name, data in kwargs.items():
self.client.set_tag(self.id, name, data)

View File

@@ -5,6 +5,7 @@
Collector module can collect objects from everywhere and process them such as merging, grouping, averaging and so on.
"""
from libs.qlib.qlib.log import TimeInspector
from typing import Callable, Dict, List
from qlib.log import get_module_logger
from qlib.utils.serial import Serializable
@@ -190,7 +191,9 @@ class RecorderCollector(Collector):
collect_dict = {}
# filter records
recs = self.experiment.list_recorders(**self.list_kwargs)
with TimeInspector.logt("Time to `list_recorders` in RecorderCollector"):
recs = self.experiment.list_recorders(**self.list_kwargs)
recs_flt = {}
for rid, rec in recs.items():
if rec_filter_func is None or rec_filter_func(rec):

View File

@@ -94,6 +94,11 @@ class TaskGen(metaclass=abc.ABCMeta):
def handler_mod(task: dict, rolling_gen):
"""
Help to modify the handler end time when using RollingGen
It try to handle the following case
- Hander's data end_time is earlier than dataset's test_data's segments.
- To handle this, handler's data's end_time is extended.
If the handler's end_time is None, then it is not necessary to change it's end time.
Args:
task (dict): a task template
@@ -112,6 +117,9 @@ def handler_mod(task: dict, rolling_gen):
except KeyError:
# Maybe dataset do not have handler, then do nothing.
pass
except TypeError:
# May be the handler is a string. `"handler.pkl"["kwargs"]` will raise TypeError
pass
class RollingGen(TaskGen):
@@ -259,3 +267,56 @@ class RollingGen(TaskGen):
# Update the following rolling
res.extend(self.gen_following_tasks(t, test_end))
return res
class MultiHorizonGenBase(TaskGen):
def __init__(self, horizon: List[int] = [5], label_leak_n=2):
"""
This task generator tries to genrate tasks for different horizons based on an existing task
Parameters
----------
horizon : List[int]
the possible horizons of the tasks
label_leak_n : int
How many future days it will take to get complete label after the day making prediction
For example:
- User make prediction on day `T`(after getting the close price on `T`)
- The label is the return of buying stock on `T + 1` and selling it on `T + 2`
- the `label_leak_n` will be 2 (e.g. two days of information is leaked to leverage this sample)
"""
self.horizon = list(horizon)
self.label_leak_n = label_leak_n
self.ta = TimeAdjuster()
self.test_key = "test"
@abc.abstractmethod
def set_horizon(self, task: dict, hr: int):
"""
This method is designed to change the task **in place**
Parameters
----------
task : dict
Qlib's task
hr : int
the horizon of task
"""
def generate(self, task: dict):
res = []
for hr in self.horizon:
# Add horizon
t = copy.deepcopy(task)
self.set_horizon(t, hr)
# adjust segment
segments = self.ta.align_seg(t["dataset"]["kwargs"]["segments"])
test_start = min(t for t in segments[self.test_key] if t is not None)
for k in list(segments.keys()):
if k != self.test_key:
segments[k] = self.ta.truncate(segments[k], test_start, hr + self.label_leak_n)
t["dataset"]["kwargs"]["segments"] = segments
res.append(t)
return res

View File

@@ -1,6 +1,5 @@
import numpy as np
import pandas as pd
import qlib.utils.index_data as idd
import unittest
@@ -115,6 +114,19 @@ class IndexDataTest(unittest.TestCase):
# sd2 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
# 2 * sd2
def test_squeeze(self):
sd1 = idd.SingleData([1, 2, 3, 4], index=["foo", "bar", "f", "g"])
# automatically squeezing
self.assertTrue(not isinstance(np.nansum(sd1), idd.IndexData))
self.assertTrue(not isinstance(np.sum(sd1), idd.IndexData))
self.assertTrue(not isinstance(sd1.sum(), idd.IndexData))
self.assertEqual(np.nansum(sd1), 10)
self.assertEqual(np.sum(sd1), 10)
self.assertEqual(sd1.sum(), 10)
self.assertEqual(np.nanmean(sd1), 2.5)
self.assertEqual(np.mean(sd1), 2.5)
self.assertEqual(sd1.mean(), 2.5)
if __name__ == "__main__":
unittest.main()

View File

@@ -47,13 +47,13 @@ def train(uri_path: str = None):
rid = recorder.id
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
pred_score = sr.load("pred.pkl")
# calculate ic and ric
sar = SigAnaRecord(recorder)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
ic = sar.load("ic.pkl")
ric = sar.load("ric.pkl")
return pred_score, {"ic": ic, "ric": ric}, rid
@@ -78,13 +78,13 @@ def train_with_sigana(uri_path: str = None):
sr = SignalRecord(model, dataset, recorder)
sr.generate()
pred_score = sr.load(sr.get_path("pred.pkl"))
pred_score = sr.load("pred.pkl")
# predict and calculate ic and ric
sar = SigAnaRecord(recorder)
sar.generate()
ic = sar.load(sar.get_path("ic.pkl"))
ric = sar.load(sar.get_path("ric.pkl"))
ic = sar.load("ic.pkl")
ric = sar.load("ric.pkl")
uri_path = R.get_uri()
return pred_score, {"ic": ic, "ric": ric}, uri_path
@@ -144,10 +144,9 @@ def backtest_analysis(pred, rid, uri_path: str = None):
},
"strategy": {
"class": "TopkDropoutStrategy",
"module_path": "qlib.contrib.strategy.model_strategy",
"module_path": "qlib.contrib.strategy.signal_strategy",
"kwargs": {
"model": model,
"dataset": dataset,
"signal": (model, dataset),
"topk": 50,
"n_drop": 5,
},
@@ -170,7 +169,7 @@ def backtest_analysis(pred, rid, uri_path: str = None):
# backtest
par = PortAnaRecord(recorder, port_analysis_config, risk_analysis_freq="day")
par.generate()
analysis_df = par.load(par.get_path("port_analysis_1day.pkl"))
analysis_df = par.load("port_analysis_1day.pkl")
print(analysis_df)
return analysis_df