mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Merge remote-tracking branch 'microsoft/main' into online_srv
This commit is contained in:
@@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 |
|
||||
|
||||
## Alpha158 dataset
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown |
|
||||
@@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of
|
||||
| ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 |
|
||||
| GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 |
|
||||
| DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 |
|
||||
| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 |
|
||||
|
||||
- The selected 20 features are based on the feature importance of a lightgbm-based model.
|
||||
- The base model of DoubleEnsemble is LGBM.
|
||||
|
||||
@@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC):
|
||||
return -1, -1
|
||||
|
||||
def get_column_definition(self):
|
||||
""""Returns formatted column definition in order expected by the TFT."""
|
||||
"""Returns formatted column definition in order expected by the TFT."""
|
||||
|
||||
column_definition = self._column_definition
|
||||
|
||||
|
||||
@@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows.
|
||||
Run the example by running the following command:
|
||||
```bash
|
||||
python workflow.py dump_and_load_dataset
|
||||
```
|
||||
```
|
||||
|
||||
## Benchmarks Performance
|
||||
### Signal Test
|
||||
Here are the results of the signal test for the benchmark models. We will keep updating the benchmark models in the future.
|
||||
| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe |
|
||||
|---|---|---|---|---|---|---|---|---|---|
|
||||
| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 |
|
||||
|
||||
@@ -0,0 +1,65 @@
|
||||
qlib_init:
|
||||
provider_uri: "~/.qlib/qlib_data/cn_data_1min"
|
||||
region: cn
|
||||
market: &market 'csi300'
|
||||
start_time: &start_time "2020-09-15 00:00:00"
|
||||
end_time: &end_time "2021-01-18 16:00:00"
|
||||
train_end_time: &train_end_time "2020-11-15 16:00:00"
|
||||
valid_start_time: &valid_start_time "2020-11-16 00:00:00"
|
||||
valid_end_time: &valid_end_time "2020-11-30 16:00:00"
|
||||
test_start_time: &test_start_time "2020-12-01 00:00:00"
|
||||
data_handler_config: &data_handler_config
|
||||
start_time: *start_time
|
||||
end_time: *end_time
|
||||
fit_start_time: *start_time
|
||||
fit_end_time: *train_end_time
|
||||
instruments: *market
|
||||
freq: '1min'
|
||||
infer_processors:
|
||||
- class: 'RobustZScoreNorm'
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
clip_outlier: false
|
||||
- class: "Fillna"
|
||||
kwargs:
|
||||
fields_group: 'feature'
|
||||
learn_processors:
|
||||
- class: 'DropnaLabel'
|
||||
- class: 'CSRankNorm'
|
||||
kwargs:
|
||||
fields_group: 'label'
|
||||
label: ["Ref($close, -2) / Ref($close, -1) - 1"]
|
||||
|
||||
task:
|
||||
model:
|
||||
class: "HFLGBModel"
|
||||
module_path: "qlib.contrib.model.highfreq_gdbt_model"
|
||||
kwargs:
|
||||
objective: 'binary'
|
||||
metric: ['binary_logloss','auc']
|
||||
verbosity: -1
|
||||
learning_rate: 0.01
|
||||
max_depth: 8
|
||||
num_leaves: 150
|
||||
lambda_l1: 1.5
|
||||
lambda_l2: 1
|
||||
num_threads: 20
|
||||
dataset:
|
||||
class: "DatasetH"
|
||||
module_path: "qlib.data.dataset"
|
||||
kwargs:
|
||||
handler:
|
||||
class: "Alpha158"
|
||||
module_path: "qlib.contrib.data.handler"
|
||||
kwargs: *data_handler_config
|
||||
segments:
|
||||
train: [*start_time, *train_end_time]
|
||||
valid: [*train_end_time, *valid_end_time]
|
||||
test: [*test_start_time, *end_time]
|
||||
record:
|
||||
- class: "SignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
- class: "HFSignalRecord"
|
||||
module_path: "qlib.workflow.record_temp"
|
||||
kwargs: {}
|
||||
@@ -15,7 +15,8 @@ LOG = get_module_logger("backtest")
|
||||
|
||||
|
||||
def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order):
|
||||
"""Parameters
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
pred : pandas.DataFrame
|
||||
predict should has <datetime, instrument> index and one `score` column
|
||||
@@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account,
|
||||
|
||||
|
||||
def update_account(trade_account, trade_info, trade_exchange, trade_date):
|
||||
"""Update the account and strategy
|
||||
"""
|
||||
Update the account and strategy
|
||||
|
||||
Parameters
|
||||
----------
|
||||
trade_account : Account()
|
||||
|
||||
@@ -128,7 +128,7 @@ class Position:
|
||||
return self.position["cash"]
|
||||
|
||||
def get_stock_amount_dict(self):
|
||||
"""generate stock amount dict {stock_id : amount of stock} """
|
||||
"""generate stock amount dict {stock_id : amount of stock}"""
|
||||
d = {}
|
||||
stock_list = self.get_stock_list()
|
||||
for stock_code in stock_list:
|
||||
|
||||
@@ -8,6 +8,59 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_long_short_prec(
    pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
) -> Tuple[pd.Series, pd.Series]:
    """
    Calculate the precision of the long and the short selection per timestamp.

    For every value of the ``date_col`` index level, the ``int(len(group) * quantile)``
    rows with the largest prediction form the long selection and the same number of
    rows with the smallest prediction form the short selection. Long precision is the
    share of long rows whose label is ``> 0``; short precision is the share of short
    rows whose label is ``< 0``. Rows with a missing label are not counted in the
    denominator.

    Parameters
    ----------
    pred : pd.Series
        prediction; index is a **pd.MultiIndex** named **[datetime, instrument]**.

        .. code-block:: python

                                              score
            datetime            instrument
            2020-12-01 09:30:00 SH600068    0.553634
                                SH600195    0.550017
                                SH600276    0.540321
                                SH600584    0.517297
                                SH600715    0.544674
    label : pd.Series
        label, indexed like ``pred``
    date_col : str
        name of the index level to group by
    quantile : float
        fraction of each group selected for the long and for the short side
    dropna : bool
        drop rows where the prediction or the label is missing before grouping
    is_alpha : bool
        if True, subtract the per-``date_col`` mean from the label first

    Returns
    -------
    (pd.Series, pd.Series)
        long precision and short precision, indexed by ``date_col``. Groups whose
        selection size is 0 are omitted.

    Raises
    ------
    ValueError
        if ``int(1 / quantile)`` is not smaller than the number of distinct values
        in index level 1.
    """
    if is_alpha:
        # `Series.mean(level=...)` was removed in pandas 2.0; transform keeps the
        # original index, so the subtraction aligns row by row.
        label = label - label.groupby(level=date_col).transform("mean")
    if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
        raise ValueError("Need more instruments to calculate precision")

    df = pd.DataFrame({"pred": pred, "label": label})
    if dropna:
        df = df.dropna()

    long_prec, short_prec = {}, {}
    for date, day_df in df.groupby(level=date_col):
        n_select = int(len(day_df) * quantile)
        if n_select == 0:
            # nothing is selected for this timestamp, so it has no precision value
            continue
        # the top/low quantile of the prediction is the long/short target
        long_label = day_df.nlargest(n_select, columns="pred")["label"]
        short_label = day_df.nsmallest(n_select, columns="pred")["label"]

        long_cnt = long_label.count()
        short_cnt = short_label.count()
        long_prec[date] = (long_label > 0).sum() / long_cnt if long_cnt else float("nan")
        short_prec[date] = (short_label < 0).sum() / short_cnt if short_cnt else float("nan")

    long_res = pd.Series(long_prec, dtype=float)
    short_res = pd.Series(short_prec, dtype=float)
    long_res.index.name = date_col
    short_res.index.name = date_col
    return long_res, short_res
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
|
||||
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
157
qlib/contrib/model/highfreq_gdbt_model.py
Normal file
@@ -0,0 +1,157 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import lightgbm as lgb
|
||||
|
||||
from qlib.model.base import ModelFT
|
||||
from qlib.data.dataset import DatasetH
|
||||
from qlib.data.dataset.handler import DataHandlerLP
|
||||
import warnings
|
||||
|
||||
|
||||
class HFLGBModel(ModelFT):
    """LightGBM Model for high frequency prediction

    The single label column is turned into an "alpha" by subtracting the mean of
    all labels that share the same index-level-0 value, and the booster is trained
    on the sign of that alpha (0 when negative, 1 otherwise).
    """

    def __init__(self, loss="mse", **kwargs):
        """
        Parameters
        ----------
        loss : str
            LightGBM objective, one of ``"mse"`` or ``"binary"``
        **kwargs
            extra LightGBM parameters; they are applied after ``objective`` and
            ``verbosity`` and therefore override them

        Raises
        ------
        NotImplementedError
            if ``loss`` is not supported
        """
        if loss not in {"mse", "binary"}:
            raise NotImplementedError
        self.params = {"objective": loss, "verbosity": -1}
        self.params.update(kwargs)
        self.model = None

    @staticmethod
    def _to_alpha(label: pd.Series) -> pd.Series:
        """Subtract from every label the mean of its index-level-0 group."""
        # `Series.mean(level=0)` was removed in pandas 2.0; transform keeps the index.
        return label - label.groupby(level=0).transform("mean")

    def _cal_signal_metrics(self, y_test, l_cut, r_cut):
        """
        Calculate the signal metrics for each index-level-0 value and average them.

        Parameters
        ----------
        y_test : pd.DataFrame
            first column is the (alpha) label, plus a ``pred`` column
        l_cut : float
            fraction of rows with the lowest prediction used as the negative side
        r_cut : float
            rows from this fraction onwards (highest predictions) form the positive side

        Returns
        -------
        tuple
            (mean positive precision, mean negative precision,
            mean positive alpha, mean negative alpha). Groups with fewer than 10
            rows on the negative side are skipped with a warning.
        """
        up_pre, down_pre = [], []
        up_alpha_ll, down_alpha_ll = [], []
        for date in y_test.index.get_level_values(0).unique():
            df_res = y_test.loc[date].sort_values("pred")
            n_rows = len(df_res)
            n_low = int(l_cut * n_rows)
            if n_low < 10:
                warnings.warn("Warning: threhold is too low or instruments number is not enough")
                continue
            label_col = df_res.columns[0]
            # rows are sorted ascending by prediction: head = lowest, tail = highest
            lowest = df_res.iloc[:n_low]
            highest = df_res.iloc[int(r_cut * n_rows) :]

            down_pre.append(len(lowest[lowest[label_col] < 0]) / len(lowest))
            up_pre.append(len(highest[highest[label_col] > 0]) / len(highest))
            down_alpha_ll.append(lowest[label_col].mean())
            up_alpha_ll.append(highest[label_col].mean())

        return (
            np.array(up_pre).mean(),
            np.array(down_pre).mean(),
            np.array(up_alpha_ll).mean(),
            np.array(down_alpha_ll).mean(),
        )

    def hf_signal_test(self, dataset: DatasetH, threhold=0.2):
        """
        Test the signal on the high frequency test set and print the results.

        Parameters
        ----------
        dataset : DatasetH
            dataset providing the ``test`` segment
        threhold : float
            fraction of rows used for the negative side; ``1 - threhold`` is the
            cut for the positive side

        Raises
        ------
        ValueError
            if the model has not been trained
        """
        if self.model is None:
            raise ValueError("Model hasn't been trained yet")
        df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I)
        # work on new objects so the frame returned by the dataset is not modified
        df_test = df_test.dropna()
        x_test = df_test["feature"]
        y_test = df_test["label"].copy()
        # Convert label into alpha
        label_col = y_test.columns[0]
        y_test[label_col] = self._to_alpha(y_test[label_col])

        y_test["pred"] = pd.Series(self.model.predict(x_test.values), index=x_test.index)

        up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold)
        print("===============================")
        print("High frequency signal test")
        print("===============================")
        print("Test set precision: ")
        print("Positive precision: {}, Negative precision: {}".format(up_p, down_p))
        print("Test Alpha Average in test set: ")
        print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a))

    def _prepare_data(self, dataset: DatasetH):
        """
        Build the LightGBM train/valid datasets with binary alpha-sign labels.

        Returns
        -------
        tuple
            ``(dtrain, dvalid)`` as ``lgb.Dataset`` objects

        Raises
        ------
        ValueError
            if the label has more than one column
        """
        df_train, df_valid = dataset.prepare(
            ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L
        )

        label_values = df_train["label"].values
        if not (label_values.ndim == 2 and label_values.shape[1] == 1):
            raise ValueError("LightGBM doesn't support multi-label training")
        l_name = df_train["label"].columns[0]

        lgb_sets = []
        for df in (df_train, df_valid):
            # Convert label into alpha. Computed on a local Series instead of a
            # chained `df["label"][l_name] = ...` assignment, which may write to a
            # temporary copy and silently leave the label unchanged.
            alpha = self._to_alpha(df["label"][l_name])
            # 0 for negative alpha, 1 otherwise
            label_c = np.where(alpha.values < 0, 0, 1)
            lgb_sets.append(lgb.Dataset(df["feature"].values, label=label_c))
        dtrain, dvalid = lgb_sets
        return dtrain, dvalid

    def fit(
        self,
        dataset: DatasetH,
        num_boost_round=1000,
        early_stopping_rounds=50,
        verbose_eval=20,
        evals_result=None,
        **kwargs
    ):
        """
        Train the booster on the ``train`` segment, validating on ``valid``.

        Parameters
        ----------
        dataset : DatasetH
            dataset providing the ``train`` and ``valid`` segments
        num_boost_round : int
            maximum number of boosting rounds
        early_stopping_rounds : int
            passed through to ``lgb.train``
        verbose_eval : int
            passed through to ``lgb.train``
        evals_result : dict, optional
            filled with the evaluation history; afterwards ``evals_result["train"]``
            and ``evals_result["valid"]`` hold the history of the first recorded
            metric only. A fresh dict is used when omitted.
        **kwargs
            forwarded to ``lgb.train``
        """
        if evals_result is None:
            # a `dict()` default would be shared between calls
            evals_result = {}
        dtrain, dvalid = self._prepare_data(dataset)
        # NOTE(review): `early_stopping_rounds`, `verbose_eval` and `evals_result`
        # are keyword arguments of older LightGBM releases (replaced by callbacks
        # in 4.0) - confirm against the pinned LightGBM version.
        self.model = lgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            valid_sets=[dtrain, dvalid],
            valid_names=["train", "valid"],
            early_stopping_rounds=early_stopping_rounds,
            verbose_eval=verbose_eval,
            evals_result=evals_result,
            **kwargs
        )
        evals_result["train"] = list(evals_result["train"].values())[0]
        evals_result["valid"] = list(evals_result["valid"].values())[0]

    def predict(self, dataset):
        """
        Predict on the ``test`` segment.

        Returns
        -------
        pd.Series
            model output indexed like the test features

        Raises
        ------
        ValueError
            if the model has not been trained
        """
        if self.model is None:
            raise ValueError("model is not fitted yet!")
        x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I)
        return pd.Series(self.model.predict(x_test.values), index=x_test.index)

    def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20):
        """
        finetune model

        Parameters
        ----------
        dataset : DatasetH
            dataset for finetuning
        num_boost_round : int
            number of round to finetune model
        verbose_eval : int
            verbose level
        """
        # Based on existing model and finetune by train more rounds
        dtrain, _ = self._prepare_data(dataset)
        self.model = lgb.train(
            self.params,
            dtrain,
            num_boost_round=num_boost_round,
            init_model=self.model,
            valid_sets=[dtrain],
            valid_names=["train"],
            verbose_eval=verbose_eval,
        )
|
||||
@@ -214,7 +214,7 @@ def cumulative_return_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df)
|
||||
|
||||
|
||||
Graph desc:
|
||||
|
||||
@@ -94,7 +94,7 @@ def rank_label_graph(
|
||||
features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max())
|
||||
features_df.columns = ['label']
|
||||
|
||||
qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max())
|
||||
|
||||
|
||||
:param position: position data; **qlib.contrib.backtest.backtest.backtest** result.
|
||||
|
||||
@@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list,
|
||||
|
||||
report_normal_df, _ = backtest(pred_df, strategy, **bparas)
|
||||
|
||||
qcr.report_graph(report_normal_df)
|
||||
qcr.analysis_position.report_graph(report_normal_df)
|
||||
|
||||
:param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**.
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path
|
||||
|
||||
|
||||
class BaseGraph:
|
||||
""""""
|
||||
""" """
|
||||
|
||||
_name = None
|
||||
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import logging
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from typing import Dict, Text, Any
|
||||
import numpy as np
|
||||
|
||||
from ...contrib.eva.alpha import calc_ic
|
||||
from ...workflow.record_temp import RecordTemp
|
||||
@@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord
|
||||
from ...data import dataset as qlib_dataset
|
||||
from ...log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class MultiSegRecord(RecordTemp):
|
||||
|
||||
@@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider):
|
||||
# if future calendar not exists, return current calendar
|
||||
if not os.path.exists(fname):
|
||||
get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!")
|
||||
get_module_logger("data").warning(
|
||||
"You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md"
|
||||
)
|
||||
fname = self._uri_cal.format(freq)
|
||||
else:
|
||||
fname = self._uri_cal.format(freq)
|
||||
@@ -1016,7 +1019,8 @@ class ClientProvider(BaseProvider):
|
||||
self.logger = get_module_logger(self.__class__.__name__)
|
||||
if isinstance(Cal, ClientCalendarProvider):
|
||||
Cal.set_conn(self.client)
|
||||
Inst.set_conn(self.client)
|
||||
if isinstance(Inst, ClientInstrumentProvider):
|
||||
Inst.set_conn(self.client)
|
||||
if hasattr(DatasetD, "provider"):
|
||||
DatasetD.provider.set_conn(self.client)
|
||||
else:
|
||||
|
||||
@@ -130,7 +130,7 @@ class FilterCol(Processor):
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
"""Use tanh to process noise data"""
|
||||
|
||||
def __call__(self, df):
|
||||
def tanh_denoise(data):
|
||||
@@ -145,7 +145,7 @@ class TanhProcess(Processor):
|
||||
|
||||
|
||||
class ProcessInf(Processor):
|
||||
"""Process infinity """
|
||||
"""Process infinity"""
|
||||
|
||||
def __call__(self, df):
|
||||
def replace_inf(data):
|
||||
|
||||
34
qlib/log.py
34
qlib/log.py
@@ -12,7 +12,37 @@ from contextlib import contextmanager
|
||||
from .config import C
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None):
|
||||
class MetaLogger(type):
    """Metaclass that seeds a new class with the class attributes of ``logging.Logger``."""

    def __new__(cls, name, bases, dict):
        # Logger attributes go in first so that anything declared in the class
        # body takes precedence over the copied entry of the same name.
        namespace = {**logging.Logger.__dict__, **dict}
        # the created class always carries Logger's docstring
        namespace["__doc__"] = logging.Logger.__doc__
        return type.__new__(cls, name, bases, namespace)
|
||||
|
||||
|
||||
class QlibLogger(metaclass=MetaLogger):
    """
    Customized logger for Qlib.

    The object stores only the module name and the desired level. The real
    ``logging.Logger`` is looked up on every access of :attr:`logger` and has the
    stored level applied, and unknown attributes are forwarded to it.
    """

    def __init__(self, module_name):
        self.module_name = module_name
        self.level = 0

    @property
    def logger(self):
        """The underlying ``logging.Logger`` with the stored level applied."""
        logger = logging.getLogger(self.module_name)
        logger.setLevel(self.level)
        return logger

    def setLevel(self, level):
        """Remember ``level``; it is applied the next time :attr:`logger` is accessed."""
        self.level = level

    def __getattr__(self, name):
        # Only reached when normal lookup fails. If the instance's own state is
        # missing (object created without __init__), resolving `self.logger`
        # would re-enter this method forever, so fail fast for those names.
        if name in {"module_name", "level", "logger"}:
            raise AttributeError(name)
        return getattr(self.logger, name)
|
||||
|
||||
|
||||
def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger:
|
||||
"""
|
||||
Get a logger for a specific module.
|
||||
|
||||
@@ -27,7 +57,7 @@ def get_module_logger(module_name, level: Optional[int] = None):
|
||||
|
||||
module_name = "qlib.{}".format(module_name)
|
||||
# Get logger.
|
||||
module_logger = logging.getLogger(module_name)
|
||||
module_logger = QlibLogger(module_name)
|
||||
module_logger.setLevel(level)
|
||||
return module_logger
|
||||
|
||||
|
||||
@@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
def predict(self, *args, **kwargs) -> object:
|
||||
""" Make predictions after modeling things """
|
||||
"""Make predictions after modeling things"""
|
||||
pass
|
||||
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" leverage Python syntactic sugar to make the models' behaviors like functions """
|
||||
"""leverage Python syntactic sugar to make the models' behaviors like functions"""
|
||||
return self.predict(*args, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -5,9 +5,9 @@ import abc
|
||||
|
||||
|
||||
class BaseOptimizer(abc.ABC):
|
||||
""" Construct portfolio with a optimization related method """
|
||||
"""Construct portfolio with a optimization related method"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def __call__(self, *args, **kwargs) -> object:
|
||||
""" Generate a optimized portfolio allocation """
|
||||
"""Generate a optimized portfolio allocation"""
|
||||
pass
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
from mlflow.entities import ViewType
|
||||
from mlflow.exceptions import MlflowException
|
||||
from pathlib import Path
|
||||
from .recorder import Recorder, MLflowRecorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Experiment:
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
import mlflow
|
||||
from mlflow.exceptions import MlflowException
|
||||
from mlflow.entities import ViewType
|
||||
import os
|
||||
import os, logging
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Text
|
||||
@@ -14,7 +14,7 @@ from ..config import C
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class ExpManager:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import re
|
||||
import re, logging
|
||||
import pandas as pd
|
||||
from pathlib import Path
|
||||
from pprint import pprint
|
||||
@@ -13,10 +13,10 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec
|
||||
from ..contrib.strategy.strategy import BaseStrategy
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class RecordTemp:
|
||||
@@ -166,6 +166,60 @@ class SignalRecord(RecordTemp):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
    """
    Signal analysis record for high-frequency predictions. It computes IC, ICIR,
    Rank IC, Rank ICIR, long/short precision and long-short return statistics from
    the stored ``pred.pkl`` and ``label.pkl``. This class inherits the ``SignalRecord`` class.
    """

    # NOTE(review): "hg" looks like a typo for "hf"; left as is because the value is
    # used as the artifact location.
    artifact_path = "hg_sig_analysis"

    def __init__(self, recorder, **kwargs):
        # extra keyword arguments are accepted but not used
        super().__init__(recorder=recorder)

    def generate(self):
        """Compute the metrics, log them to the recorder, save the series and print the metrics."""
        # only the first column of the prediction and of the label is analysed
        score = self.load("pred.pkl").iloc[:, 0]
        label = self.load("label.pkl").iloc[:, 0]

        long_pre, short_pre = calc_long_short_prec(score, label, is_alpha=True)
        ic, ric = calc_ic(score, label)
        long_short_r, long_avg_r = calc_long_short_return(score, label)

        metrics = {
            "IC": ic.mean(),
            "ICIR": ic.mean() / ic.std(),
            "Rank IC": ric.mean(),
            "Rank ICIR": ric.mean() / ric.std(),
            "Long precision": long_pre.mean(),
            "Short precision": short_pre.mean(),
            "Long-Short Average Return": long_short_r.mean(),
            "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(),
        }
        objects = {
            "ic.pkl": ic,
            "ric.pkl": ric,
            "long_pre.pkl": long_pre,
            "short_pre.pkl": short_pre,
            "long_short_r.pkl": long_short_r,
            "long_avg_r.pkl": long_avg_r,
        }

        self.recorder.log_metrics(**metrics)
        self.recorder.save_objects(**objects, artifact_path=self.get_path())
        pprint(metrics)

    def list(self):
        """Return the paths of the objects saved by :meth:`generate`."""
        file_names = (
            "ic.pkl",
            "ric.pkl",
            "long_pre.pkl",
            "short_pre.pkl",
            "long_short_r.pkl",
            "long_avg_r.pkl",
        )
        return [self.get_path(file_name) for file_name in file_names]
|
||||
|
||||
|
||||
class SigAnaRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import mlflow
|
||||
import mlflow, logging
|
||||
import shutil, os, pickle, tempfile, codecs, pickle
|
||||
from pathlib import Path
|
||||
from datetime import datetime
|
||||
from ..utils.objm import FileManager
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
class Recorder:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys, traceback, signal, atexit
|
||||
import sys, traceback, signal, atexit, logging
|
||||
from . import R
|
||||
from .recorder import Recorder
|
||||
from ..log import get_module_logger
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
logger = get_module_logger("workflow", logging.INFO)
|
||||
|
||||
|
||||
# function to handle the experiment when unusual program ending occurs
|
||||
|
||||
24
scripts/data_collector/contrib/README.md
Normal file
24
scripts/data_collector/contrib/README.md
Normal file
@@ -0,0 +1,24 @@
|
||||
# Get future trading days
|
||||
|
||||
> The calendar generated here is the one read by `D.calendar(future=True)`.
|
||||
|
||||
## Requirements
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Collect Data
|
||||
|
||||
```bash
|
||||
# collect future trading dates and write them to <qlib_dir>/calendars/<freq>_future.txt
|
||||
python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day
|
||||
```
|
||||
|
||||
## Parameters
|
||||
|
||||
- qlib_dir: qlib data directory
|
||||
- freq: value from [`day`, `1min`], default `day`
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import sys
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from loguru import logger
|
||||
|
||||
# get data from baostock
|
||||
import baostock as bs
|
||||
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
|
||||
|
||||
from data_collector.utils import generate_minutes_calendar_from_daily
|
||||
|
||||
|
||||
def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame:
    """Load ``<qlib_dir>/calendars/day.txt`` as a header-less frame.

    Returns an empty ``DataFrame`` when the file does not exist.
    """
    day_file = qlib_dir / "calendars" / "day.txt"
    if day_file.exists():
        return pd.read_csv(day_file, header=None)
    return pd.DataFrame()
|
||||
|
||||
|
||||
def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"):
    """Write ``date_list``, one entry per line, to ``<qlib_dir>/calendars/<freq>_future.txt``."""
    future_file = qlib_dir.joinpath("calendars", f"{freq}_future.txt")
    calendar_path = str(future_file)

    # fmt="%s" stores every entry as plain text
    np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8")
    logger.info(f"write future calendars success: {calendar_path}")
|
||||
|
||||
|
||||
def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]:
    """Convert a list of trading days into calendar entries for ``freq``.

    Parameters
    ----------
    date_list: List[str]
        trading days
    freq: str
        ``"day"`` returns ``date_list`` unchanged; ``"1min"`` expands every day
        into minute timestamps formatted as ``%Y-%m-%d %H:%M:%S``

    Raises
    ------
    ValueError
        if ``freq`` is neither ``"day"`` nor ``"1min"``
    """
    if freq == "day":
        return date_list
    if freq == "1min":
        minutes = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist()
        return [pd.Timestamp(_minute).strftime("%Y-%m-%d %H:%M:%S") for _minute in minutes]
    raise ValueError(f"Unsupported freq: {freq}")
|
||||
|
||||
|
||||
def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"):
    """get future calendar

    Queries baostock for the trading days from January 1st of the year of the last
    day in the local daily calendar (the current year when there is none) to
    December 31st of the current year, and writes them to
    ``<qlib_dir>/calendars/<freq>_future.txt``. Login or query failures are logged
    and nothing is written.

    Parameters
    ----------
    qlib_dir: str or Path
        qlib data directory
    freq: str
        value from ["day", "1min"], by default day

    Raises
    ------
    FileNotFoundError
        if ``qlib_dir`` does not exist
    """
    qlib_dir = Path(qlib_dir).expanduser().resolve()
    if not qlib_dir.exists():
        raise FileNotFoundError(str(qlib_dir))

    lg = bs.login()
    if lg.error_code != "0":
        logger.error(f"login error: {lg.error_msg}")
        return
    try:
        # read daily calendar
        daily_calendar = read_calendar_from_qlib(qlib_dir)
        end_year = pd.Timestamp.now().year
        if daily_calendar.empty:
            start_year = end_year
        else:
            start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year
        # both bounds are passed as "YYYY-MM-DD" strings
        # NOTE(review): baostock documents its date arguments as strings - confirm
        rs = bs.query_trade_dates(start_date=f"{start_year}-01-01", end_date=f"{end_year}-12-31")
        if rs.error_code != "0":
            # do not overwrite an existing future calendar with an empty one
            logger.error(f"query trade dates error: {rs.error_msg}")
            return
        data_list = []
        # `and` short-circuits, so `rs.next()` is not called once an error is reported
        while rs.error_code == "0" and rs.next():
            _row_data = rs.get_row_data()
            # presumably column 1 is the "is trading day" flag and column 0 the date - TODO confirm
            if int(_row_data[1]) == 1:
                data_list.append(_row_data[0])
    finally:
        # always release the baostock session opened above
        bs.logout()

    date_list = generate_qlib_calendar(sorted(data_list), freq=freq)
    write_calendar_to_qlib(qlib_dir, date_list, freq=freq)
    logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(future_calendar_collector)
|
||||
5
scripts/data_collector/contrib/requirements.txt
Normal file
5
scripts/data_collector/contrib/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
baostock
|
||||
fire
|
||||
numpy
|
||||
pandas
|
||||
loguru
|
||||
@@ -10,7 +10,9 @@ import random
|
||||
import requests
|
||||
import functools
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Tuple
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lxml import etree
|
||||
from loguru import logger
|
||||
@@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh
|
||||
return res
|
||||
|
||||
|
||||
def generate_minutes_calendar_from_daily(
    calendars: Iterable,
    freq: str = "1min",
    am_range: Tuple[str, str] = ("09:30:00", "11:29:00"),
    pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"),
) -> pd.Index:
    """generate minutes calendar

    Parameters
    ----------
    calendars: Iterable
        daily calendar
    freq: str
        by default 1min
    am_range: Tuple[str, str]
        AM Time Range, by default China-Stock: ("09:30:00", "11:29:00")
    pm_range: Tuple[str, str]
        PM Time Range, by default China-Stock: ("13:00:00", "14:59:00")

    Returns
    -------
    pd.Index
        sorted, de-duplicated timestamps covering both ranges of every day;
        an empty ``DatetimeIndex`` when ``calendars`` is empty
    """
    daily_format: str = "%Y-%m-%d"
    res = []
    for _day in calendars:
        # format the day once instead of once per range boundary
        _day_str = pd.Timestamp(_day).strftime(daily_format)
        for _range in [am_range, pm_range]:
            res.append(pd.date_range(f"{_day_str} {_range[0]}", f"{_day_str} {_range[1]}", freq=freq))

    if not res:
        # np.hstack raises on an empty sequence
        return pd.DatetimeIndex([])
    return pd.Index(sorted(set(np.hstack(res))))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM
|
||||
|
||||
@@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN
|
||||
CUR_DIR = Path(__file__).resolve().parent
|
||||
sys.path.append(str(CUR_DIR.parent.parent))
|
||||
from data_collector.base import BaseCollector, BaseNormalize, BaseRun
|
||||
from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols
|
||||
from data_collector.utils import (
|
||||
get_calendar_list,
|
||||
get_hs_stock_symbols,
|
||||
get_us_stock_symbols,
|
||||
generate_minutes_calendar_from_daily,
|
||||
)
|
||||
|
||||
INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}"
|
||||
|
||||
@@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC):
|
||||
return calendar_list_1d
|
||||
|
||||
def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index:
    """Expand a daily calendar into a 1-minute calendar.

    Delegates to the shared ``generate_minutes_calendar_from_daily`` helper using
    this class's ``AM_RANGE`` / ``PM_RANGE`` instead of repeating the same loop here.
    """
    # NOTE(review): the helper formats days as "%Y-%m-%d" rather than
    # self.DAILY_FORMAT - confirm the two are the same.
    return generate_minutes_calendar_from_daily(
        calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE
    )
|
||||
|
||||
def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame:
|
||||
# TODO: using daily data factor
|
||||
|
||||
Reference in New Issue
Block a user