From 1ad237f89fc5197a6629b8e2df2217dd3e2fb712 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 14:20:44 +0800 Subject: [PATCH 01/19] update high freq demo --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 65 ++++++++ qlib/contrib/eva/alpha.py | 40 +++++ qlib/contrib/model/highfreq_gdbt_model.py | 157 ++++++++++++++++++ qlib/workflow/record_temp.py | 50 +++++- 4 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml create mode 100644 qlib/contrib/model/highfreq_gdbt_model.py diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml new file mode 100644 index 000000000..ca8e92d08 --- /dev/null +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -0,0 +1,65 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + region: cn +market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +start_time: &start_time "2020-09-15 00:00:00" +end_time: &end_time "2021-01-18 16:00:00" +train_end_time: &train_end_time "2020-11-15 16:00:00" +valid_start_time: &valid_start_time "2020-11-16 00:00:00" +valid_end_time: &valid_end_time "2020-11-30 16:00:00" +test_start_time: &test_start_time "2020-12-01 00:00:00" +data_handler_config: &data_handler_config + start_time: *start_time + end_time: *end_time + fit_start_time: *start_time + fit_end_time: *train_end_time + instruments: *market + freq: '1min' + infer_processors: + - class: 'RobustZScoreNorm' + kwargs: + fields_group: 'feature' + clip_outlier: false + - class: "Fillna" + kwargs: + fields_group: 'feature' + learn_processors: + - class: 'DropnaLabel' + - class: 'CSRankNorm' + kwargs: + fields_group: 'label' + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +task: + model: + class: "HF_LGBModel" + module_path: "qlib.contrib.model.highfreq_gdbt_model" + kwargs: + objective: 'binary' + metric: ['binary_logloss','auc'] + verbosity: -1 + learning_rate: 0.01 + max_depth: 8 + num_leaves: 150 + lambda_l1: 1.5 + lambda_l2: 1 + num_threads: 20 + dataset: + class: "DatasetH" + module_path: "qlib.data.dataset" + kwargs: + handler: + class: "Alpha158" + module_path: "qlib.contrib.data.handler" + kwargs: *data_handler_config + segments: + train: [*start_time, *train_end_time] + valid: [*train_end_time, *valid_end_time] + test: [*test_start_time, *end_time] + record: + - class: "SignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} + - class: "HFSignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} \ No newline at end of file diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index c68571853..e2beafc13 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,6 +7,46 @@ import pandas as pd from typing import Tuple +def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: + """ calculate the precision + pred : + pred + label : + label + date_col : + date_col + + Returns + ------- + (pd.Series, pd.Series) + long precision and short precision in time level + """ + if is_alpha: + label = label - label.mean(level=0) + if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + raise ValueError("Need more instruments to calculate precision") + + + df = pd.DataFrame({"pred": pred, "label": label}) + if dropna: + df.dropna(inplace = True) + + group = df.groupby(level=date_col) + + N = lambda x: int(len(x) * quantile) + # find the top/low quantile of prediction and treat them as long and short target + long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) + short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) + + groupll = long.groupby(date_col) + ll_ration = groupll.apply(lambda x: x > 0) + ll_c = groupll.count() + + groups = short.groupby(date_col) + s_ration = groups.apply(lambda x: x < 0) + s_c = groups.count() + return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: """calc_ic. diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py new file mode 100644 index 000000000..62e45c841 --- /dev/null +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd +import lightgbm as lgb + +from qlib.model.base import ModelFT +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +import warnings + + +class HF_LGBModel(ModelFT): + """LightGBM Model""" + + def __init__(self, loss="mse", **kwargs): + if loss not in {"mse", "binary"}: + raise NotImplementedError + self.params = {"objective": loss, "verbosity": -1} + self.params.update(kwargs) + self.model = None + + def _cal_signal_metrics(self, y_test, l_cut, r_cut): + """ + Calcaute the signal metrics by daily level + """ + up_pre, down_pre = [], [] + up_alpha_ll, down_alpha_ll = [], [] + for date in y_test.index.get_level_values(0).unique(): + df_res = y_test.loc[date].sort_values("pred") + if int(l_cut * len(df_res)) < 10: + warnings.warn("Warning: threhold is too low or instruments number is not enough") + continue + top = df_res.iloc[: int(l_cut * len(df_res))] + bottom = df_res.iloc[int(r_cut * len(df_res)) :] + + down_precision = len(top[top[top.columns[0]] < 0]) / (len(top)) + up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom)) + + down_alpha = top[top.columns[0]].mean() + up_alpha = bottom[bottom.columns[0]].mean() + + up_pre.append(up_precision) + down_pre.append(down_precision) + up_alpha_ll.append(up_alpha) + down_alpha_ll.append(down_alpha) + + return ( + np.array(up_pre).mean(), + np.array(down_pre).mean(), + np.array(up_alpha_ll).mean(), + np.array(down_alpha_ll).mean(), + ) + + def hf_signal_test(self, dataset: DatasetH, threhold=0.2): + """ + Test the sigal in high frequency test set + """ + if self.model == None: + raise ValueError("Model hasn't been trained yet") + df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + df_test.dropna(inplace=True) + x_test, y_test = df_test["feature"], df_test["label"] + # Convert label into alpha + y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0) + + res = pd.Series(self.model.predict(x_test.values), index=x_test.index) + y_test["pred"] = res + + up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold) + print("===============================") + print("High frequency signal test") + print("===============================") + print("Test set precision: ") + print("Positive precision: {}, Negative precision: {}".format(up_p, down_p)) + print("Test Alpha Average in test set: ") + print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a)) + + def _prepare_data(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_train["feature"], df_valid["label"] + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + l_name = df_train["label"].columns[0] + # Convert label into alpha + df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0) + df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0) + mapping_fn = lambda x: 0 if x < 0 else 1 + df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn) + df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn) + x_train, y_train = df_train["feature"], df_train["label_c"].values + x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train.values, label=y_train) + dvalid = lgb.Dataset(x_valid.values, label=y_valid) + return dtrain, dvalid + + def fit( + self, + dataset: DatasetH, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): + dtrain, dvalid = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + + def predict(self, dataset): + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + return pd.Series(self.model.predict(x_test.values), index=x_test.index) + + def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): + """ + finetune model + + Parameters + ---------- + dataset : DatasetH + dataset for finetuning + num_boost_round : int + number of round to finetune model + verbose_eval : int + verbose level + """ + # Based on existing model and finetune by train more rounds + dtrain, _ = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + init_model=self.model, + valid_sets=[dtrain], + valid_names=["train"], + verbose_eval=verbose_eval, + ) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..8ab8405a5 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -154,6 +154,54 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) + + +class HFSignalRecord(SignalRecord): + """ + This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. + """ + artifact_path = "hg_sig_analysis" + + def __init__(self, recorder, **kwargs): + super().__init__(recorder=recorder) + + def generate(self): + pred = self.load("pred.pkl") + raw_label = self.load("label.pkl") + + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics = { + "IC": ic.mean(), + "ICIR": ic.mean() / ic.std(), + "Rank IC": ric.mean(), + "Rank ICIR": ric.mean() / ric.std(), + "Long precision": long_pre.mean(), + "Short precision": short_pre.mean() + } + objects = {"ic.pkl": ic, "ric.pkl": ric} + objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) + long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics.update( + { + "Long-Short Average Return": long_short_r.mean(), + "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(), + } + ) + objects.update( + { + "long_short_r.pkl": long_short_r, + "long_avg_r.pkl": long_avg_r, + } + ) + self.recorder.log_metrics(**metrics) + self.recorder.save_objects(**objects, artifact_path=self.get_path()) + pprint(metrics) + + def list(self): + paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + return paths class SigAnaRecord(SignalRecord): From 3bf6c7f95f5cc77d4025358e618d5f688138f5cc Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 15:37:54 +0800 Subject: [PATCH 02/19] update format --- qlib/contrib/eva/alpha.py | 24 +++++++++++++----------- qlib/workflow/record_temp.py | 16 +++++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index e2beafc13..8078dd4ed 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,15 +7,18 @@ import pandas as pd from typing import Tuple -def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: - """ calculate the precision + +def calc_prec( + pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False +) -> Tuple[pd.Series, pd.Series]: + """calculate the precision pred : pred label : label date_col : date_col - + Returns ------- (pd.Series, pd.Series) @@ -23,29 +26,28 @@ def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: """ if is_alpha: label = label - label.mean(level=0) - if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") - df = pd.DataFrame({"pred": pred, "label": label}) if dropna: - df.dropna(inplace = True) - + df.dropna(inplace=True) + group = df.groupby(level=date_col) - + N = lambda x: int(len(x) * quantile) # find the top/low quantile of prediction and treat them as long and short target long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) - + groupll = long.groupby(date_col) ll_ration = groupll.apply(lambda x: x > 0) ll_c = groupll.count() - + groups = short.groupby(date_col) s_ration = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 8ab8405a5..c47b999f3 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -154,12 +154,13 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) - - + + class HFSignalRecord(SignalRecord): """ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. """ + artifact_path = "hg_sig_analysis" def __init__(self, recorder, **kwargs): @@ -169,7 +170,7 @@ class HFSignalRecord(SignalRecord): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -177,7 +178,7 @@ class HFSignalRecord(SignalRecord): "Rank IC": ric.mean(), "Rank ICIR": ric.mean() / ric.std(), "Long precision": long_pre.mean(), - "Short precision": short_pre.mean() + "Short precision": short_pre.mean(), } objects = {"ic.pkl": ic, "ric.pkl": ric} objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) @@ -199,7 +200,12 @@ class HFSignalRecord(SignalRecord): pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths = [ + self.get_path("ic.pkl"), + self.get_path("ric.pkl"), + self.get_path("long_pre.pkl"), + self.get_path("short_pre.pkl"), + ] paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From e3739bb980b5347520a13fd510bf9bf7180c8905 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 24 Mar 2021 15:47:26 +0800 Subject: [PATCH 03/19] fix naming and code style --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 2 +- qlib/contrib/eva/alpha.py | 29 +++++++++++++------ qlib/contrib/model/highfreq_gdbt_model.py | 4 +-- qlib/workflow/record_temp.py | 8 ++--- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index ca8e92d08..c21ef1da3 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -32,7 +32,7 @@ data_handler_config: &data_handler_config task: model: - class: "HF_LGBModel" + class: "HFLGBModel" module_path: "qlib.contrib.model.highfreq_gdbt_model" kwargs: objective: 'binary' diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 8078dd4ed..fadef9d16 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -8,12 +8,23 @@ import pandas as pd from typing import Tuple -def calc_prec( +def calc_long_short_prec( pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False ) -> Tuple[pd.Series, pd.Series]: - """calculate the precision - pred : - pred + """ + calculate the precision for long and short operation + + + :param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**. + + .. code-block:: python + score + datetime instrument + 2020-12-01 09:30:00 SH600068 0.553634 + SH600195 0.550017 + SH600276 0.540321 + SH600584 0.517297 + SH600715 0.544674 label : label date_col : @@ -25,7 +36,7 @@ def calc_prec( long precision and short precision in time level """ if is_alpha: - label = label - label.mean(level=0) + label = label - label.mean(level=date_col) if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") @@ -41,13 +52,13 @@ def calc_prec( short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) groupll = long.groupby(date_col) - ll_ration = groupll.apply(lambda x: x > 0) - ll_c = groupll.count() + l_dom = groupll.apply(lambda x: x > 0) + l_c = groupll.count() groups = short.groupby(date_col) - s_ration = groups.apply(lambda x: x < 0) + s_dom = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) + return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index 62e45c841..5a2eeb50a 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -11,8 +11,8 @@ from qlib.data.dataset.handler import DataHandlerLP import warnings -class HF_LGBModel(ModelFT): - """LightGBM Model""" +class HFLGBModel(ModelFT): + """LightGBM Model for high frequency prediction""" def __init__(self, loss="mse", **kwargs): if loss not in {"mse", "binary"}: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index c47b999f3..239527fa0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -169,8 +169,7 @@ class HFSignalRecord(SignalRecord): def generate(self): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) + long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -205,8 +204,9 @@ class HFSignalRecord(SignalRecord): self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl"), + self.get_path("long_short_r.pkl"), + self.get_path("long_avg_r.pkl"), ] - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From bed1175e2404ccd4d711bb71aff9577c8449c6a9 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Tue, 30 Mar 2021 19:29:17 +0800 Subject: [PATCH 04/19] update dataset --- .../highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index c21ef1da3..45c59c670 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -1,7 +1,7 @@ qlib_init: - provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + provider_uri: "~/.qlib/qlib_data/cn_data_1min" region: cn -market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +market: &market 'csi300' start_time: &start_time "2020-09-15 00:00:00" end_time: &end_time "2021-01-18 16:00:00" train_end_time: &train_end_time "2020-11-15 16:00:00" From fe190dec4b6670a8e0d5410545d7bb8a13304157 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 14:40:28 +0800 Subject: [PATCH 05/19] update readme --- examples/highfreq/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md index 30c2e19db..c07d8a2a0 100644 --- a/examples/highfreq/README.md +++ b/examples/highfreq/README.md @@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows. Run the example by running the following command: ```bash python workflow.py dump_and_load_dataset -``` \ No newline at end of file +``` + +## Benchmarks Performance +### Signal Test +Here are the results of signal test for benchmark models. We will keep updating benchmark models in future. +| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe | +|---|---|---|---|---|---|---|---|---|---| +| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 | From 941c980d06371b83cf54eef8e84b0614104eb5d4 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 17:35:19 +0800 Subject: [PATCH 06/19] update tabnet --- examples/benchmarks/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index f1e7437fa..c3d965d85 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 | +| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 | ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | @@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 | | GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 | +| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 | - The selected 20 features are based on the feature importance of a lightgbm-based model. - The base model of DoubleEnsemble is LGBM. From 848d953226cedc782f8949838698801458b1a829 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 09:58:55 +0800 Subject: [PATCH 07/19] Update qlib logger --- qlib/log.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 126acb9d2..ed050f6c9 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,6 +11,26 @@ from contextlib import contextmanager from .config import C +class QlibLogger(Logger,meta=): + ''' + Customized logger for Qlib. + ''' + def __init__(self, module_name): + self.module_name = module_name + self.level = 0 + + @property + def logger(self): + logger = logging.getLogger(self.module_name) + logger.setLevel(self.level) + return logger + + def setLevel(self, level): + self.level = level + + def __getattr__(self, name): + return self.logger.__getattribute__(name) + def get_module_logger(module_name, level: Optional[int] = None): """ @@ -27,7 +47,7 @@ def get_module_logger(module_name, level: Optional[int] = None): module_name = "qlib.{}".format(module_name) # Get logger. - module_logger = logging.getLogger(module_name) + module_logger = QlibLogger(module_name) module_logger.setLevel(level) return module_logger From 78bb8882cd4f23e20a14d69682b54cdd24a3e200 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 12:00:18 +0800 Subject: [PATCH 08/19] Format --- qlib/log.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index ed050f6c9..017e8e339 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,10 +11,12 @@ from contextlib import contextmanager from .config import C -class QlibLogger(Logger,meta=): - ''' + +class QlibLogger: + """ Customized logger for Qlib. - ''' + """ + def __init__(self, module_name): self.module_name = module_name self.level = 0 @@ -27,10 +29,10 @@ class QlibLogger(Logger,meta=): def setLevel(self, level): self.level = level - + def __getattr__(self, name): return self.logger.__getattribute__(name) - + def get_module_logger(module_name, level: Optional[int] = None): """ From f4bfe8e6197aa52bfb759c3981346953f8306f41 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 14:35:05 +0800 Subject: [PATCH 09/19] First trial of adding docstring --- qlib/log.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 017e8e339..c7d269f4d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -12,7 +12,23 @@ from contextlib import contextmanager from .config import C -class QlibLogger: +class MetaLogger(type): + def __init__(self, name, bases, dic): + super().__init__(name, bases, dic) + + def __new__(cls, name, bases, dict): + wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict.update(dict) + wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + return type.__new__(cls, name, bases, wrapper_dict) + + def __call__(cls, *args, **kwargs): + obj = cls.__new__(cls) + cls.__init__(cls, *args, **kwargs) + return obj + + +class QlibLogger(metaclass=MetaLogger): """ Customized logger for Qlib. """ From 4ebf68479416932d8e28fdd4af289655e54a254f Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 15:35:11 +0800 Subject: [PATCH 10/19] Update workflow logging --- qlib/log.py | 4 ++-- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 4 ++-- qlib/workflow/record_temp.py | 4 ++-- qlib/workflow/recorder.py | 4 ++-- qlib/workflow/utils.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index c7d269f4d..4ecdceef2 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -17,9 +17,9 @@ class MetaLogger(type): super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) def __call__(cls, *args, **kwargs): diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index dd73f7f52..7b3d1f507 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from pathlib import Path from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Experiment: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5275e57d7..590790c9e 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -4,7 +4,7 @@ import mlflow from mlflow.exceptions import MlflowException from mlflow.entities import ViewType -import os +import os, logging from pathlib import Path from contextlib import contextmanager from typing import Optional, Text @@ -14,7 +14,7 @@ from ..config import C from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class ExpManager: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index dee327f64..5732c95a9 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import re +import re, logging import pandas as pd from pathlib import Path from pprint import pprint @@ -16,7 +16,7 @@ from ..utils import flatten_dict from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class RecordTemp: diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5915e58da..b9b2fd1b3 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path from datetime import datetime from ..utils.objm import FileManager from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Recorder: diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 33d251dd8..596ff0927 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys, traceback, signal, atexit +import sys, traceback, signal, atexit, logging from . import R from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) # function to handle the experiment when unusual program ending occurs From cbf1fa721ed85f0d2e89ff19f9ec0e08af2339c2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Sat, 17 Apr 2021 15:47:49 +0800 Subject: [PATCH 11/19] Update --- qlib/contrib/workflow/record_temp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 12792fbcb..bedf89105 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import logging import pandas as pd +import numpy as np from sklearn.metrics import mean_squared_error from typing import Dict, Text, Any -import numpy as np from ...contrib.eva.alpha import calc_ic from ...workflow.record_temp import RecordTemp @@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord from ...data import dataset as qlib_dataset from ...log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class MultiSegRecord(RecordTemp): From 6a05d4e2559f1917e6411478426cab6c4f6eaa78 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 19 Apr 2021 11:36:00 +0800 Subject: [PATCH 12/19] Enable IDEs docstrings --- qlib/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 4ecdceef2..8b123d05d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -50,7 +50,7 @@ class QlibLogger(metaclass=MetaLogger): return self.logger.__getattribute__(name) -def get_module_logger(module_name, level: Optional[int] = None): +def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger: """ Get a logger for a specific module. From aafaff45d2b0d2740d83b2651a8887f51011037b Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:13:36 +0800 Subject: [PATCH 13/19] Update doc --- qlib/contrib/backtest/backtest.py | 7 +++++-- qlib/contrib/report/analysis_position/cumulative_return.py | 2 +- qlib/contrib/report/analysis_position/rank_label.py | 2 +- qlib/contrib/report/analysis_position/report.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index b87d6afe3..909948c25 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -15,7 +15,8 @@ LOG = get_module_logger("backtest") def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters + """ + Parameters ---------- pred : pandas.DataFrame predict should has index and one `score` column @@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy + """ + Update the account and strategy + Parameters ---------- trade_account : Account() diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index abb68ea60..00985a17c 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -214,7 +214,7 @@ def cumulative_return_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.cumulative_return_graph(positions, report_normal_df, features_df) + qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df) Graph desc: diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index 72a358adc..77743b10c 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -94,7 +94,7 @@ def rank_label_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) + qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index f82e654c4..6b83f0734 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, report_normal_df, _ = backtest(pred_df, strategy, **bparas) - qcr.report_graph(report_normal_df) + qcr.analysis_position.report_graph(report_normal_df) :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**. From 8adfafa6aa6f76591ae2af537f9d8ad91ccb6c43 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:17:25 +0800 Subject: [PATCH 14/19] Black format --- qlib/contrib/backtest/backtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 909948c25..fc30065fd 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -127,7 +127,7 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): """ Update the account and strategy - + Parameters ---------- trade_account : Account() From fbff4c271a7e74f2f0b4770912abf2fb01a9354b Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 00:38:45 +0800 Subject: [PATCH 15/19] Remove redundant methods in meta --- qlib/log.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 8b123d05d..3b3362d5b 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,8 +13,6 @@ from .config import C class MetaLogger(type): - def __init__(self, name, bases, dic): - super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() @@ -22,11 +20,6 @@ class MetaLogger(type): wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) - def __call__(cls, *args, **kwargs): - obj = cls.__new__(cls) - cls.__init__(cls, *args, **kwargs) - return obj - class QlibLogger(metaclass=MetaLogger): """ From e410caaa8fb315de7898035986ec7cca58384bf0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 10:08:12 +0800 Subject: [PATCH 16/19] Simplify meta class --- qlib/log.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 3b3362d5b..5888b3841 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,11 +13,10 @@ from .config import C class MetaLogger(type): - def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() + wrapper_dict = logging.Logger.__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ + wrapper_dict["__doc__"] = logging.Logger.__doc__ return type.__new__(cls, name, bases, wrapper_dict) From e15ea06122bd570706ac8b6d3ab6b96b5ee64edb Mon Sep 17 00:00:00 2001 From: zhupr Date: Sun, 25 Apr 2021 23:50:29 +0800 Subject: [PATCH 17/19] Fix ClientProvider not supporting LocalInstrumentProvider && online using the latest python-socketio --- qlib/data/data.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..cea2f42eb 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -1016,7 +1016,8 @@ class ClientProvider(BaseProvider): self.logger = get_module_logger(self.__class__.__name__) if isinstance(Cal, ClientCalendarProvider): Cal.set_conn(self.client) - Inst.set_conn(self.client) + if isinstance(Inst, ClientInstrumentProvider): + Inst.set_conn(self.client) if hasattr(DatasetD, "provider"): DatasetD.provider.set_conn(self.client) else: diff --git a/setup.py b/setup.py index 83cf6e1b6..747d885f4 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ REQUIRED = [ "scipy>=1.0.0", "requests>=2.18.0", "sacred>=0.7.4", - "python-socketio==3.1.2", + "python-socketio", "redis>=3.0.1", "python-redis-lock>=3.3.1", "schedule>=0.6.0", From 5a7eecabeefdf5218a4a4ea1db5ed94343df6c42 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 27 Apr 2021 04:04:43 +0000 Subject: [PATCH 18/19] black formating (black is upgraded in github) --- examples/benchmarks/TFT/data_formatters/base.py | 2 +- qlib/contrib/backtest/position.py | 2 +- qlib/contrib/report/graph.py | 2 +- qlib/data/dataset/processor.py | 4 ++-- qlib/model/base.py | 4 ++-- qlib/portfolio/optimizer/base.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py index c68a192ba..aa1c0dc82 100644 --- a/examples/benchmarks/TFT/data_formatters/base.py +++ b/examples/benchmarks/TFT/data_formatters/base.py @@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC): return -1, -1 def get_column_definition(self): - """"Returns formatted column definition in order expected by the TFT.""" + """Returns formatted column definition in order expected by the TFT.""" column_definition = self._column_definition diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..97abc2a56 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -128,7 +128,7 @@ class Position: return self.position["cash"] def get_stock_amount_dict(self): - """generate stock amount dict {stock_id : amount of stock} """ + """generate stock amount dict {stock_id : amount of stock}""" d = {} stock_list = self.get_stock_list() for stock_code in stock_list: diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 677e767ee..2d4f546e8 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path class BaseGraph: - """""" + """ """ _name = None diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index e035f5624..7635a4127 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -129,7 +129,7 @@ class FilterCol(Processor): class TanhProcess(Processor): - """ Use tanh to process noise data""" + """Use tanh to process noise data""" def __call__(self, df): def tanh_denoise(data): @@ -144,7 +144,7 @@ class TanhProcess(Processor): class ProcessInf(Processor): - """Process infinity """ + """Process infinity""" def __call__(self, df): def replace_inf(data): diff --git a/qlib/model/base.py b/qlib/model/base.py index 1ac8f2fc9..12caf5f73 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): @abc.abstractmethod def predict(self, *args, **kwargs) -> object: - """ Make predictions after modeling things """ + """Make predictions after modeling things""" pass def __call__(self, *args, **kwargs) -> object: - """ leverage Python syntactic sugar to make the models' behaviors like functions """ + """leverage Python syntactic sugar to make the models' behaviors like functions""" return self.predict(*args, **kwargs) diff --git a/qlib/portfolio/optimizer/base.py b/qlib/portfolio/optimizer/base.py index 502443869..e3f692014 100644 --- a/qlib/portfolio/optimizer/base.py +++ b/qlib/portfolio/optimizer/base.py @@ -5,9 +5,9 @@ import abc class BaseOptimizer(abc.ABC): - """ Construct portfolio with a optimization related method """ + """Construct portfolio with a optimization related method""" @abc.abstractmethod def __call__(self, *args, **kwargs) -> object: - """ Generate a optimized portfolio allocation """ + """Generate a optimized portfolio allocation""" pass From 8b8d21107c7f6dd6f6e6db371f4591179a4ad616 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 27 Apr 2021 21:20:47 +0800 Subject: [PATCH 19/19] Add future trading date collector --- qlib/data/data.py | 3 + scripts/data_collector/contrib/README.md | 24 +++++ .../contrib/future_trading_date_collector.py | 87 +++++++++++++++++++ .../data_collector/contrib/requirements.txt | 5 ++ scripts/data_collector/utils.py | 37 ++++++++ scripts/data_collector/yahoo/collector.py | 25 ++---- 6 files changed, 165 insertions(+), 16 deletions(-) create mode 100644 scripts/data_collector/contrib/README.md create mode 100644 scripts/data_collector/contrib/future_trading_date_collector.py create mode 100644 scripts/data_collector/contrib/requirements.txt diff --git a/qlib/data/data.py b/qlib/data/data.py index cea2f42eb..c2638e234 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider): # if future calendar not exists, return current calendar if not os.path.exists(fname): get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) diff --git a/scripts/data_collector/contrib/README.md b/scripts/data_collector/contrib/README.md new file mode 100644 index 000000000..011ff56e6 --- /dev/null +++ b/scripts/data_collector/contrib/README.md @@ -0,0 +1,24 @@ +# Get future trading days + +> `D.calendar(future=True)` will be used + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. +python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day +``` + +## Parameters + +- qlib_dir: qlib data directory +- freq: value from [`day`, `1min`], default `day` + + + diff --git a/scripts/data_collector/contrib/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector.py new file mode 100644 index 000000000..4da62d465 --- /dev/null +++ b/scripts/data_collector/contrib/future_trading_date_collector.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from typing import List +from pathlib import Path + +import fire +import numpy as np +import pandas as pd +from loguru import logger + +# get data from baostock +import baostock as bs + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + + +from data_collector.utils import generate_minutes_calendar_from_daily + + +def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame: + calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt") + if not calendar_path.exists(): + return pd.DataFrame() + return pd.read_csv(calendar_path, header=None) + + +def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"): + calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt")) + + np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8") + logger.info(f"write future calendars success: {calendar_path}") + + +def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]: + print(freq) + if freq == "day": + return date_list + elif freq == "1min": + date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist() + return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list)) + else: + raise ValueError(f"Unsupported freq: {freq}") + + +def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"): + """get future calendar + + Parameters + ---------- + qlib_dir: str or Path + qlib data directory + freq: str + value from ["day", "1min"], by default day + """ + qlib_dir = Path(qlib_dir).expanduser().resolve() + if not qlib_dir.exists(): + raise FileNotFoundError(str(qlib_dir)) + + lg = bs.login() + if lg.error_code != "0": + logger.error(f"login error: {lg.error_msg}") + return + # read daily calendar + daily_calendar = read_calendar_from_qlib(qlib_dir) + end_year = pd.Timestamp.now().year + if daily_calendar.empty: + start_year = pd.Timestamp.now().year + else: + start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year + rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31") + data_list = [] + while (rs.error_code == "0") & rs.next(): + _row_data = rs.get_row_data() + if int(_row_data[1]) == 1: + data_list.append(_row_data[0]) + data_list = sorted(data_list) + date_list = generate_qlib_calendar(data_list, freq=freq) + write_calendar_to_qlib(qlib_dir, date_list, freq=freq) + bs.logout() + logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31") + + +if __name__ == "__main__": + fire.Fire(future_calendar_collector) diff --git a/scripts/data_collector/contrib/requirements.txt b/scripts/data_collector/contrib/requirements.txt new file mode 100644 index 000000000..92dcb2374 --- /dev/null +++ b/scripts/data_collector/contrib/requirements.txt @@ -0,0 +1,5 @@ +baostock +fire +numpy +pandas +loguru diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index e8c9b9dc4..3f4539612 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -10,7 +10,9 @@ import random import requests import functools from pathlib import Path +from typing import Iterable, Tuple +import numpy as np import pandas as pd from lxml import etree from loguru import logger @@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh return res +def generate_minutes_calendar_from_daily( + calendars: Iterable, + freq: str = "1min", + am_range: Tuple[str, str] = ("09:30:00", "11:29:00"), + pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"), +) -> pd.Index: + """generate minutes calendar + + Parameters + ---------- + calendars: Iterable + daily calendar + freq: str + by default 1min + am_range: Tuple[str, str] + AM Time Range, by default China-Stock: ("09:30:00", "11:29:00") + pm_range: Tuple[str, str] + PM Time Range, by default China-Stock: ("13:00:00", "14:59:00") + + """ + daily_format: str = "%Y-%m-%d" + res = [] + for _day in calendars: + for _range in [am_range, pm_range]: + res.append( + pd.date_range( + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}", + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}", + freq=freq, + ) + ) + + return pd.Index(sorted(set(np.hstack(res)))) + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f0e110694..a6e06613e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols +from data_collector.utils import ( + get_calendar_list, + get_hs_stock_symbols, + get_us_stock_symbols, + generate_minutes_calendar_from_daily, +) INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" @@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC): return calendar_list_1d def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index: - res = [] - daily_format = self.DAILY_FORMAT - am_range = self.AM_RANGE - pm_range = self.PM_RANGE - for _day in calendars: - for _range in [am_range, pm_range]: - res.append( - pd.date_range( - f"{_day.strftime(daily_format)} {_range[0]}", - f"{_day.strftime(daily_format)} {_range[1]}", - freq="1min", - ) - ) - - return pd.Index(sorted(set(np.hstack(res)))) + return generate_minutes_calendar_from_daily( + calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE + ) def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor