From 1ad237f89fc5197a6629b8e2df2217dd3e2fb712 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 14:20:44 +0800 Subject: [PATCH 01/30] update high freq demo --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 65 ++++++++ qlib/contrib/eva/alpha.py | 40 +++++ qlib/contrib/model/highfreq_gdbt_model.py | 157 ++++++++++++++++++ qlib/workflow/record_temp.py | 50 +++++- 4 files changed, 311 insertions(+), 1 deletion(-) create mode 100644 examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml create mode 100644 qlib/contrib/model/highfreq_gdbt_model.py diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml new file mode 100644 index 000000000..ca8e92d08 --- /dev/null +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -0,0 +1,65 @@ +qlib_init: + provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + region: cn +market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +start_time: &start_time "2020-09-15 00:00:00" +end_time: &end_time "2021-01-18 16:00:00" +train_end_time: &train_end_time "2020-11-15 16:00:00" +valid_start_time: &valid_start_time "2020-11-16 00:00:00" +valid_end_time: &valid_end_time "2020-11-30 16:00:00" +test_start_time: &test_start_time "2020-12-01 00:00:00" +data_handler_config: &data_handler_config + start_time: *start_time + end_time: *end_time + fit_start_time: *start_time + fit_end_time: *train_end_time + instruments: *market + freq: '1min' + infer_processors: + - class: 'RobustZScoreNorm' + kwargs: + fields_group: 'feature' + clip_outlier: false + - class: "Fillna" + kwargs: + fields_group: 'feature' + learn_processors: + - class: 'DropnaLabel' + - class: 'CSRankNorm' + kwargs: + fields_group: 'label' + label: ["Ref($close, -2) / Ref($close, -1) - 1"] + +task: + model: + class: "HF_LGBModel" + module_path: "qlib.contrib.model.highfreq_gdbt_model" + kwargs: + objective: 'binary' + metric: ['binary_logloss','auc'] + verbosity: -1 + learning_rate: 0.01 + max_depth: 8 + num_leaves: 150 + lambda_l1: 1.5 + lambda_l2: 1 + num_threads: 20 + dataset: + class: "DatasetH" + module_path: "qlib.data.dataset" + kwargs: + handler: + class: "Alpha158" + module_path: "qlib.contrib.data.handler" + kwargs: *data_handler_config + segments: + train: [*start_time, *train_end_time] + valid: [*train_end_time, *valid_end_time] + test: [*test_start_time, *end_time] + record: + - class: "SignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} + - class: "HFSignalRecord" + module_path: "qlib.workflow.record_temp" + kwargs: {} \ No newline at end of file diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index c68571853..e2beafc13 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,6 +7,46 @@ import pandas as pd from typing import Tuple +def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: + """ calculate the precision + pred : + pred + label : + label + date_col : + date_col + + Returns + ------- + (pd.Series, pd.Series) + long precision and short precision in time level + """ + if is_alpha: + label = label - label.mean(level=0) + if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + raise ValueError("Need more instruments to calculate precision") + + + df = pd.DataFrame({"pred": pred, "label": label}) + if dropna: + df.dropna(inplace = True) + + group = df.groupby(level=date_col) + + N = lambda x: int(len(x) * quantile) + # find the top/low quantile of prediction and treat them as long and short target + long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) + short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) + + groupll = long.groupby(date_col) + ll_ration = groupll.apply(lambda x: x > 0) + ll_c = groupll.count() + + groups = short.groupby(date_col) + s_ration = groups.apply(lambda x: x < 0) + s_c = groups.count() + return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: """calc_ic. diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py new file mode 100644 index 000000000..62e45c841 --- /dev/null +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -0,0 +1,157 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import numpy as np +import pandas as pd +import lightgbm as lgb + +from qlib.model.base import ModelFT +from qlib.data.dataset import DatasetH +from qlib.data.dataset.handler import DataHandlerLP +import warnings + + +class HF_LGBModel(ModelFT): + """LightGBM Model""" + + def __init__(self, loss="mse", **kwargs): + if loss not in {"mse", "binary"}: + raise NotImplementedError + self.params = {"objective": loss, "verbosity": -1} + self.params.update(kwargs) + self.model = None + + def _cal_signal_metrics(self, y_test, l_cut, r_cut): + """ + Calcaute the signal metrics by daily level + """ + up_pre, down_pre = [], [] + up_alpha_ll, down_alpha_ll = [], [] + for date in y_test.index.get_level_values(0).unique(): + df_res = y_test.loc[date].sort_values("pred") + if int(l_cut * len(df_res)) < 10: + warnings.warn("Warning: threhold is too low or instruments number is not enough") + continue + top = df_res.iloc[: int(l_cut * len(df_res))] + bottom = df_res.iloc[int(r_cut * len(df_res)) :] + + down_precision = len(top[top[top.columns[0]] < 0]) / (len(top)) + up_precision = len(bottom[bottom[top.columns[0]] > 0]) / (len(bottom)) + + down_alpha = top[top.columns[0]].mean() + up_alpha = bottom[bottom.columns[0]].mean() + + up_pre.append(up_precision) + down_pre.append(down_precision) + up_alpha_ll.append(up_alpha) + down_alpha_ll.append(down_alpha) + + return ( + np.array(up_pre).mean(), + np.array(down_pre).mean(), + np.array(up_alpha_ll).mean(), + np.array(down_alpha_ll).mean(), + ) + + def hf_signal_test(self, dataset: DatasetH, threhold=0.2): + """ + Test the sigal in high frequency test set + """ + if self.model == None: + raise ValueError("Model hasn't been trained yet") + df_test = dataset.prepare("test", col_set=["feature", "label"], data_key=DataHandlerLP.DK_I) + df_test.dropna(inplace=True) + x_test, y_test = df_test["feature"], df_test["label"] + # Convert label into alpha + y_test[y_test.columns[0]] = y_test[y_test.columns[0]] - y_test[y_test.columns[0]].mean(level=0) + + res = pd.Series(self.model.predict(x_test.values), index=x_test.index) + y_test["pred"] = res + + up_p, down_p, up_a, down_a = self._cal_signal_metrics(y_test, threhold, 1 - threhold) + print("===============================") + print("High frequency signal test") + print("===============================") + print("Test set precision: ") + print("Positive precision: {}, Negative precision: {}".format(up_p, down_p)) + print("Test Alpha Average in test set: ") + print("Positive average alpha: {}, Negative average alpha: {}".format(up_a, down_a)) + + def _prepare_data(self, dataset: DatasetH): + df_train, df_valid = dataset.prepare( + ["train", "valid"], col_set=["feature", "label"], data_key=DataHandlerLP.DK_L + ) + + x_train, y_train = df_train["feature"], df_train["label"] + x_valid, y_valid = df_train["feature"], df_valid["label"] + if y_train.values.ndim == 2 and y_train.values.shape[1] == 1: + l_name = df_train["label"].columns[0] + # Convert label into alpha + df_train["label"][l_name] = df_train["label"][l_name] - df_train["label"][l_name].mean(level=0) + df_valid["label"][l_name] = df_valid["label"][l_name] - df_valid["label"][l_name].mean(level=0) + mapping_fn = lambda x: 0 if x < 0 else 1 + df_train["label_c"] = df_train["label"][l_name].apply(mapping_fn) + df_valid["label_c"] = df_valid["label"][l_name].apply(mapping_fn) + x_train, y_train = df_train["feature"], df_train["label_c"].values + x_valid, y_valid = df_valid["feature"], df_valid["label_c"].values + else: + raise ValueError("LightGBM doesn't support multi-label training") + + dtrain = lgb.Dataset(x_train.values, label=y_train) + dvalid = lgb.Dataset(x_valid.values, label=y_valid) + return dtrain, dvalid + + def fit( + self, + dataset: DatasetH, + num_boost_round=1000, + early_stopping_rounds=50, + verbose_eval=20, + evals_result=dict(), + **kwargs + ): + dtrain, dvalid = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + valid_sets=[dtrain, dvalid], + valid_names=["train", "valid"], + early_stopping_rounds=early_stopping_rounds, + verbose_eval=verbose_eval, + evals_result=evals_result, + **kwargs + ) + evals_result["train"] = list(evals_result["train"].values())[0] + evals_result["valid"] = list(evals_result["valid"].values())[0] + + def predict(self, dataset): + if self.model is None: + raise ValueError("model is not fitted yet!") + x_test = dataset.prepare("test", col_set="feature", data_key=DataHandlerLP.DK_I) + return pd.Series(self.model.predict(x_test.values), index=x_test.index) + + def finetune(self, dataset: DatasetH, num_boost_round=10, verbose_eval=20): + """ + finetune model + + Parameters + ---------- + dataset : DatasetH + dataset for finetuning + num_boost_round : int + number of round to finetune model + verbose_eval : int + verbose level + """ + # Based on existing model and finetune by train more rounds + dtrain, _ = self._prepare_data(dataset) + self.model = lgb.train( + self.params, + dtrain, + num_boost_round=num_boost_round, + init_model=self.model, + valid_sets=[dtrain], + valid_names=["train"], + verbose_eval=verbose_eval, + ) diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 2c1b6fecc..8ab8405a5 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -154,6 +154,54 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) + + +class HFSignalRecord(SignalRecord): + """ + This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. + """ + artifact_path = "hg_sig_analysis" + + def __init__(self, recorder, **kwargs): + super().__init__(recorder=recorder) + + def generate(self): + pred = self.load("pred.pkl") + raw_label = self.load("label.pkl") + + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics = { + "IC": ic.mean(), + "ICIR": ic.mean() / ic.std(), + "Rank IC": ric.mean(), + "Rank ICIR": ric.mean() / ric.std(), + "Long precision": long_pre.mean(), + "Short precision": short_pre.mean() + } + objects = {"ic.pkl": ic, "ric.pkl": ric} + objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) + long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], raw_label.iloc[:, 0]) + metrics.update( + { + "Long-Short Average Return": long_short_r.mean(), + "Long-Short Average Sharpe": long_short_r.mean() / long_short_r.std(), + } + ) + objects.update( + { + "long_short_r.pkl": long_short_r, + "long_avg_r.pkl": long_avg_r, + } + ) + self.recorder.log_metrics(**metrics) + self.recorder.save_objects(**objects, artifact_path=self.get_path()) + pprint(metrics) + + def list(self): + paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) + return paths class SigAnaRecord(SignalRecord): From 3bf6c7f95f5cc77d4025358e618d5f688138f5cc Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 15:37:54 +0800 Subject: [PATCH 02/30] update format --- qlib/contrib/eva/alpha.py | 24 +++++++++++++----------- qlib/workflow/record_temp.py | 16 +++++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index e2beafc13..8078dd4ed 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,15 +7,18 @@ import pandas as pd from typing import Tuple -def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: - """ calculate the precision + +def calc_prec( + pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False +) -> Tuple[pd.Series, pd.Series]: + """calculate the precision pred : pred label : label date_col : date_col - + Returns ------- (pd.Series, pd.Series) @@ -23,29 +26,28 @@ def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: """ if is_alpha: label = label - label.mean(level=0) - if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") - df = pd.DataFrame({"pred": pred, "label": label}) if dropna: - df.dropna(inplace = True) - + df.dropna(inplace=True) + group = df.groupby(level=date_col) - + N = lambda x: int(len(x) * quantile) # find the top/low quantile of prediction and treat them as long and short target long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) - + groupll = long.groupby(date_col) ll_ration = groupll.apply(lambda x: x > 0) ll_c = groupll.count() - + groups = short.groupby(date_col) s_ration = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 8ab8405a5..c47b999f3 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -154,12 +154,13 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) - - + + class HFSignalRecord(SignalRecord): """ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. """ + artifact_path = "hg_sig_analysis" def __init__(self, recorder, **kwargs): @@ -169,7 +170,7 @@ class HFSignalRecord(SignalRecord): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -177,7 +178,7 @@ class HFSignalRecord(SignalRecord): "Rank IC": ric.mean(), "Rank ICIR": ric.mean() / ric.std(), "Long precision": long_pre.mean(), - "Short precision": short_pre.mean() + "Short precision": short_pre.mean(), } objects = {"ic.pkl": ic, "ric.pkl": ric} objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) @@ -199,7 +200,12 @@ class HFSignalRecord(SignalRecord): pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths = [ + self.get_path("ic.pkl"), + self.get_path("ric.pkl"), + self.get_path("long_pre.pkl"), + self.get_path("short_pre.pkl"), + ] paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From e3739bb980b5347520a13fd510bf9bf7180c8905 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 24 Mar 2021 15:47:26 +0800 Subject: [PATCH 03/30] fix naming and code style --- ...rkflow_config_High_Freq_Tree_Alpha158.yaml | 2 +- qlib/contrib/eva/alpha.py | 29 +++++++++++++------ qlib/contrib/model/highfreq_gdbt_model.py | 4 +-- qlib/workflow/record_temp.py | 8 ++--- 4 files changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index ca8e92d08..c21ef1da3 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -32,7 +32,7 @@ data_handler_config: &data_handler_config task: model: - class: "HF_LGBModel" + class: "HFLGBModel" module_path: "qlib.contrib.model.highfreq_gdbt_model" kwargs: objective: 'binary' diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index 8078dd4ed..fadef9d16 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -8,12 +8,23 @@ import pandas as pd from typing import Tuple -def calc_prec( +def calc_long_short_prec( pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False ) -> Tuple[pd.Series, pd.Series]: - """calculate the precision - pred : - pred + """ + calculate the precision for long and short operation + + + :param pred/label: index is **pd.MultiIndex**, index name is **[datetime, instruments]**; columns names is **[score]**. + + .. code-block:: python + score + datetime instrument + 2020-12-01 09:30:00 SH600068 0.553634 + SH600195 0.550017 + SH600276 0.540321 + SH600584 0.517297 + SH600715 0.544674 label : label date_col : @@ -25,7 +36,7 @@ def calc_prec( long precision and short precision in time level """ if is_alpha: - label = label - label.mean(level=0) + label = label - label.mean(level=date_col) if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") @@ -41,13 +52,13 @@ def calc_prec( short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) groupll = long.groupby(date_col) - ll_ration = groupll.apply(lambda x: x > 0) - ll_c = groupll.count() + l_dom = groupll.apply(lambda x: x > 0) + l_c = groupll.count() groups = short.groupby(date_col) - s_ration = groups.apply(lambda x: x < 0) + s_dom = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) + return (l_dom.groupby(date_col).sum() / l_c), (s_dom.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/contrib/model/highfreq_gdbt_model.py b/qlib/contrib/model/highfreq_gdbt_model.py index 62e45c841..5a2eeb50a 100644 --- a/qlib/contrib/model/highfreq_gdbt_model.py +++ b/qlib/contrib/model/highfreq_gdbt_model.py @@ -11,8 +11,8 @@ from qlib.data.dataset.handler import DataHandlerLP import warnings -class HF_LGBModel(ModelFT): - """LightGBM Model""" +class HFLGBModel(ModelFT): + """LightGBM Model for high frequency prediction""" def __init__(self, loss="mse", **kwargs): if loss not in {"mse", "binary"}: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index c47b999f3..239527fa0 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -13,7 +13,7 @@ from ..data.dataset.handler import DataHandlerLP from ..utils import init_instance_by_config, get_module_by_module_path from ..log import get_module_logger from ..utils import flatten_dict -from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_prec +from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy logger = get_module_logger("workflow", "INFO") @@ -169,8 +169,7 @@ class HFSignalRecord(SignalRecord): def generate(self): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) + long_pre, short_pre = calc_long_short_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -205,8 +204,9 @@ class HFSignalRecord(SignalRecord): self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl"), + self.get_path("long_short_r.pkl"), + self.get_path("long_avg_r.pkl"), ] - paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths From bed1175e2404ccd4d711bb71aff9577c8449c6a9 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Tue, 30 Mar 2021 19:29:17 +0800 Subject: [PATCH 04/30] update dataset --- .../highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml index c21ef1da3..45c59c670 100644 --- a/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml +++ b/examples/highfreq/workflow_config_High_Freq_Tree_Alpha158.yaml @@ -1,7 +1,7 @@ qlib_init: - provider_uri: "~/.qlib/qlib_data/yahoo_cn_1min" + provider_uri: "~/.qlib/qlib_data/cn_data_1min" region: cn -market: &market ['SH605222', 'SZ002796', 'SZ002246', 'SZ000713', 'SZ002820', 'SH601328', 'SZ000668', 'SH603359', 'SZ002144', 'SH600195', 'SH603685', 'SH603386', 'SZ002586', 'SZ000573', 'SZ000605', 'SZ002842', 'SH600068', 'SZ300547', 'SZ000926', 'SZ002036', 'SZ002161', 'SH600715', 'SZ300427', 'SZ002573', 'SZ300142', 'SH605116', 'SZ002951', 'SH600276', 'SZ002437', 'SH603355', 'SZ002893', 'SH600584'] +market: &market 'csi300' start_time: &start_time "2020-09-15 00:00:00" end_time: &end_time "2021-01-18 16:00:00" train_end_time: &train_end_time "2020-11-15 16:00:00" From fe190dec4b6670a8e0d5410545d7bb8a13304157 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 14:40:28 +0800 Subject: [PATCH 05/30] update readme --- examples/highfreq/README.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/highfreq/README.md b/examples/highfreq/README.md index 30c2e19db..c07d8a2a0 100644 --- a/examples/highfreq/README.md +++ b/examples/highfreq/README.md @@ -25,4 +25,11 @@ The example is given in `workflow.py`, users can run the code as follows. Run the example by running the following command: ```bash python workflow.py dump_and_load_dataset -``` \ No newline at end of file +``` + +## Benchmarks Performance +### Signal Test +Here are the results of signal test for benchmark models. We will keep updating benchmark models in future. +| Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Long precision| Short Precision | Long-Short Average Return | Long-Short Average Sharpe | +|---|---|---|---|---|---|---|---|---|---| +| LightGBM | Alpha158 | 0.3042±0.00 | 1.5372±0.00| 0.3117±0.00 | 1.6258±0.00 | 0.6720±0.00 | 0.6870±0.00 | 0.000769±0.00 | 1.0190±0.00 | From 941c980d06371b83cf54eef8e84b0614104eb5d4 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 14 Apr 2021 17:35:19 +0800 Subject: [PATCH 06/30] update tabnet --- examples/benchmarks/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/benchmarks/README.md b/examples/benchmarks/README.md index f1e7437fa..c3d965d85 100644 --- a/examples/benchmarks/README.md +++ b/examples/benchmarks/README.md @@ -17,6 +17,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha360 | 0.0493±0.01 | 0.3778±0.06| 0.0585±0.00 | 0.4606±0.04 | 0.0513±0.03 | 0.6727±0.38| -0.1085±0.02 | | GATs (Petar Velickovic, et al.) | Alpha360 | 0.0475±0.00 | 0.3515±0.02| 0.0592±0.00 | 0.4585±0.01 | 0.0876±0.02 | 1.1513±0.27| -0.0795±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha360 | 0.0407±0.00| 0.3053±0.00 | 0.0490±0.00 | 0.3840±0.00 | 0.0380±0.02 | 0.5000±0.21 | -0.0984±0.02 | +| TabNet (Sercan O. Arik, et al.)| Alpha360 | 0.0192±0.00 | 0.1401±0.00| 0.0291±0.00 | 0.2163±0.00 | -0.0258±0.00 | -0.2961±0.00| -0.1429±0.00 | ## Alpha158 dataset | Model Name | Dataset | IC | ICIR | Rank IC | Rank ICIR | Annualized Return | Information Ratio | Max Drawdown | @@ -32,6 +33,7 @@ The numbers shown below demonstrate the performance of the entire `workflow` of | ALSTM (Yao Qin, et al.) | Alpha158 (with selected 20 features) | 0.0385±0.01 | 0.3022±0.06| 0.0478±0.00 | 0.3874±0.04 | 0.0486±0.03 | 0.7141±0.45| -0.1088±0.03 | | GATs (Petar Velickovic, et al.) | Alpha158 (with selected 20 features) | 0.0349±0.00 | 0.2511±0.01| 0.0457±0.00 | 0.3537±0.01 | 0.0578±0.02 | 0.8221±0.25| -0.0824±0.02 | | DoubleEnsemble (Chuheng Zhang, et al.) | Alpha158 | 0.0544±0.00 | 0.4338±0.01 | 0.0523±0.00 | 0.4257±0.01 | 0.1253±0.01 | 1.4105±0.14 | -0.0902±0.01 | +| TabNet (Sercan O. Arik, et al.)| Alpha158 | 0.0383±0.00 | 0.3414±0.00| 0.0388±0.00 | 0.3460±0.00 | 0.0226±0.00 | 0.2652±0.00| -0.1072±0.00 | - The selected 20 features are based on the feature importance of a lightgbm-based model. - The base model of DoubleEnsemble is LGBM. From 848d953226cedc782f8949838698801458b1a829 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 09:58:55 +0800 Subject: [PATCH 07/30] Update qlib logger --- qlib/log.py | 22 +++++++++++++++++++++- 1 file changed, 21 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 126acb9d2..ed050f6c9 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,6 +11,26 @@ from contextlib import contextmanager from .config import C +class QlibLogger(Logger,meta=): + ''' + Customized logger for Qlib. + ''' + def __init__(self, module_name): + self.module_name = module_name + self.level = 0 + + @property + def logger(self): + logger = logging.getLogger(self.module_name) + logger.setLevel(self.level) + return logger + + def setLevel(self, level): + self.level = level + + def __getattr__(self, name): + return self.logger.__getattribute__(name) + def get_module_logger(module_name, level: Optional[int] = None): """ @@ -27,7 +47,7 @@ def get_module_logger(module_name, level: Optional[int] = None): module_name = "qlib.{}".format(module_name) # Get logger. - module_logger = logging.getLogger(module_name) + module_logger = QlibLogger(module_name) module_logger.setLevel(level) return module_logger From 78bb8882cd4f23e20a14d69682b54cdd24a3e200 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 12:00:18 +0800 Subject: [PATCH 08/30] Format --- qlib/log.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index ed050f6c9..017e8e339 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -11,10 +11,12 @@ from contextlib import contextmanager from .config import C -class QlibLogger(Logger,meta=): - ''' + +class QlibLogger: + """ Customized logger for Qlib. - ''' + """ + def __init__(self, module_name): self.module_name = module_name self.level = 0 @@ -27,10 +29,10 @@ class QlibLogger(Logger,meta=): def setLevel(self, level): self.level = level - + def __getattr__(self, name): return self.logger.__getattribute__(name) - + def get_module_logger(module_name, level: Optional[int] = None): """ From f4bfe8e6197aa52bfb759c3981346953f8306f41 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 14:35:05 +0800 Subject: [PATCH 09/30] First trial of adding docstring --- qlib/log.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 017e8e339..c7d269f4d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -12,7 +12,23 @@ from contextlib import contextmanager from .config import C -class QlibLogger: +class MetaLogger(type): + def __init__(self, name, bases, dic): + super().__init__(name, bases, dic) + + def __new__(cls, name, bases, dict): + wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict.update(dict) + wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + return type.__new__(cls, name, bases, wrapper_dict) + + def __call__(cls, *args, **kwargs): + obj = cls.__new__(cls) + cls.__init__(cls, *args, **kwargs) + return obj + + +class QlibLogger(metaclass=MetaLogger): """ Customized logger for Qlib. """ From 4ebf68479416932d8e28fdd4af289655e54a254f Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 16 Apr 2021 15:35:11 +0800 Subject: [PATCH 10/30] Update workflow logging --- qlib/log.py | 4 ++-- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 4 ++-- qlib/workflow/record_temp.py | 4 ++-- qlib/workflow/recorder.py | 4 ++-- qlib/workflow/utils.py | 4 ++-- 6 files changed, 12 insertions(+), 12 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index c7d269f4d..4ecdceef2 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -17,9 +17,9 @@ class MetaLogger(type): super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("module_name")).__dict__.copy() + wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("module_name").__doc__ + wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) def __call__(cls, *args, **kwargs): diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index dd73f7f52..7b3d1f507 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException from pathlib import Path from .recorder import Recorder, MLflowRecorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Experiment: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5275e57d7..590790c9e 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -4,7 +4,7 @@ import mlflow from mlflow.exceptions import MlflowException from mlflow.entities import ViewType -import os +import os, logging from pathlib import Path from contextlib import contextmanager from typing import Optional, Text @@ -14,7 +14,7 @@ from ..config import C from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class ExpManager: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index dee327f64..5732c95a9 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import re +import re, logging import pandas as pd from pathlib import Path from pprint import pprint @@ -16,7 +16,7 @@ from ..utils import flatten_dict from ..contrib.eva.alpha import calc_ic, calc_long_short_return, calc_long_short_prec from ..contrib.strategy.strategy import BaseStrategy -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class RecordTemp: diff --git a/qlib/workflow/recorder.py b/qlib/workflow/recorder.py index 5915e58da..b9b2fd1b3 100644 --- a/qlib/workflow/recorder.py +++ b/qlib/workflow/recorder.py @@ -1,14 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import mlflow +import mlflow, logging import shutil, os, pickle, tempfile, codecs, pickle from pathlib import Path from datetime import datetime from ..utils.objm import FileManager from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class Recorder: diff --git a/qlib/workflow/utils.py b/qlib/workflow/utils.py index 33d251dd8..596ff0927 100644 --- a/qlib/workflow/utils.py +++ b/qlib/workflow/utils.py @@ -1,12 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import sys, traceback, signal, atexit +import sys, traceback, signal, atexit, logging from . import R from .recorder import Recorder from ..log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) # function to handle the experiment when unusual program ending occurs From cbf1fa721ed85f0d2e89ff19f9ec0e08af2339c2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Sat, 17 Apr 2021 15:47:49 +0800 Subject: [PATCH 11/30] Update --- qlib/contrib/workflow/record_temp.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/contrib/workflow/record_temp.py b/qlib/contrib/workflow/record_temp.py index 12792fbcb..bedf89105 100644 --- a/qlib/contrib/workflow/record_temp.py +++ b/qlib/contrib/workflow/record_temp.py @@ -1,10 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import logging import pandas as pd +import numpy as np from sklearn.metrics import mean_squared_error from typing import Dict, Text, Any -import numpy as np from ...contrib.eva.alpha import calc_ic from ...workflow.record_temp import RecordTemp @@ -12,7 +13,7 @@ from ...workflow.record_temp import SignalRecord from ...data import dataset as qlib_dataset from ...log import get_module_logger -logger = get_module_logger("workflow", "INFO") +logger = get_module_logger("workflow", logging.INFO) class MultiSegRecord(RecordTemp): From 6a05d4e2559f1917e6411478426cab6c4f6eaa78 Mon Sep 17 00:00:00 2001 From: Jactus Date: Mon, 19 Apr 2021 11:36:00 +0800 Subject: [PATCH 12/30] Enable IDEs docstrings --- qlib/log.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/log.py b/qlib/log.py index 4ecdceef2..8b123d05d 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -50,7 +50,7 @@ class QlibLogger(metaclass=MetaLogger): return self.logger.__getattribute__(name) -def get_module_logger(module_name, level: Optional[int] = None): +def get_module_logger(module_name, level: Optional[int] = None) -> logging.Logger: """ Get a logger for a specific module. From aafaff45d2b0d2740d83b2651a8887f51011037b Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:13:36 +0800 Subject: [PATCH 13/30] Update doc --- qlib/contrib/backtest/backtest.py | 7 +++++-- qlib/contrib/report/analysis_position/cumulative_return.py | 2 +- qlib/contrib/report/analysis_position/rank_label.py | 2 +- qlib/contrib/report/analysis_position/report.py | 2 +- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index b87d6afe3..909948c25 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -15,7 +15,8 @@ LOG = get_module_logger("backtest") def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, benchmark, return_order): - """Parameters + """ + Parameters ---------- pred : pandas.DataFrame predict should has index and one `score` column @@ -124,7 +125,9 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): - """Update the account and strategy + """ + Update the account and strategy + Parameters ---------- trade_account : Account() diff --git a/qlib/contrib/report/analysis_position/cumulative_return.py b/qlib/contrib/report/analysis_position/cumulative_return.py index abb68ea60..00985a17c 100644 --- a/qlib/contrib/report/analysis_position/cumulative_return.py +++ b/qlib/contrib/report/analysis_position/cumulative_return.py @@ -214,7 +214,7 @@ def cumulative_return_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close - 1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.cumulative_return_graph(positions, report_normal_df, features_df) + qcr.analysis_position.cumulative_return_graph(positions, report_normal_df, features_df) Graph desc: diff --git a/qlib/contrib/report/analysis_position/rank_label.py b/qlib/contrib/report/analysis_position/rank_label.py index 72a358adc..77743b10c 100644 --- a/qlib/contrib/report/analysis_position/rank_label.py +++ b/qlib/contrib/report/analysis_position/rank_label.py @@ -94,7 +94,7 @@ def rank_label_graph( features_df = D.features(D.instruments('csi500'), ['Ref($close, -1)/$close-1'], pred_df_dates.min(), pred_df_dates.max()) features_df.columns = ['label'] - qcr.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) + qcr.analysis_position.rank_label_graph(positions, features_df, pred_df_dates.min(), pred_df_dates.max()) :param position: position data; **qlib.contrib.backtest.backtest.backtest** result. diff --git a/qlib/contrib/report/analysis_position/report.py b/qlib/contrib/report/analysis_position/report.py index f82e654c4..6b83f0734 100644 --- a/qlib/contrib/report/analysis_position/report.py +++ b/qlib/contrib/report/analysis_position/report.py @@ -186,7 +186,7 @@ def report_graph(report_df: pd.DataFrame, show_notebook: bool = True) -> [list, report_normal_df, _ = backtest(pred_df, strategy, **bparas) - qcr.report_graph(report_normal_df) + qcr.analysis_position.report_graph(report_normal_df) :param report_df: **df.index.name** must be **date**, **df.columns** must contain **return**, **turnover**, **cost**, **bench**. From 8adfafa6aa6f76591ae2af537f9d8ad91ccb6c43 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 22 Apr 2021 14:17:25 +0800 Subject: [PATCH 14/30] Black format --- qlib/contrib/backtest/backtest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/qlib/contrib/backtest/backtest.py b/qlib/contrib/backtest/backtest.py index 909948c25..fc30065fd 100644 --- a/qlib/contrib/backtest/backtest.py +++ b/qlib/contrib/backtest/backtest.py @@ -127,7 +127,7 @@ def backtest(pred, strategy, executor, trade_exchange, shift, verbose, account, def update_account(trade_account, trade_info, trade_exchange, trade_date): """ Update the account and strategy - + Parameters ---------- trade_account : Account() From fbff4c271a7e74f2f0b4770912abf2fb01a9354b Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 00:38:45 +0800 Subject: [PATCH 15/30] Remove redundant methods in meta --- qlib/log.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 8b123d05d..3b3362d5b 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,8 +13,6 @@ from .config import C class MetaLogger(type): - def __init__(self, name, bases, dic): - super().__init__(name, bases, dic) def __new__(cls, name, bases, dict): wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() @@ -22,11 +20,6 @@ class MetaLogger(type): wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ return type.__new__(cls, name, bases, wrapper_dict) - def __call__(cls, *args, **kwargs): - obj = cls.__new__(cls) - cls.__init__(cls, *args, **kwargs) - return obj - class QlibLogger(metaclass=MetaLogger): """ From e410caaa8fb315de7898035986ec7cca58384bf0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 23 Apr 2021 10:08:12 +0800 Subject: [PATCH 16/30] Simplify meta class --- qlib/log.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 3b3362d5b..5888b3841 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -13,11 +13,10 @@ from .config import C class MetaLogger(type): - def __new__(cls, name, bases, dict): - wrapper_dict = type(logging.getLogger("MetaLogger")).__dict__.copy() + wrapper_dict = logging.Logger.__dict__.copy() wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.getLogger("MetaLogger").__doc__ + wrapper_dict["__doc__"] = logging.Logger.__doc__ return type.__new__(cls, name, bases, wrapper_dict) From e15ea06122bd570706ac8b6d3ab6b96b5ee64edb Mon Sep 17 00:00:00 2001 From: zhupr Date: Sun, 25 Apr 2021 23:50:29 +0800 Subject: [PATCH 17/30] Fix ClientProvider not supporting LocalInstrumentProvider && online using the latest python-socketio --- qlib/data/data.py | 3 ++- setup.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/qlib/data/data.py b/qlib/data/data.py index 000bd1196..cea2f42eb 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -1016,7 +1016,8 @@ class ClientProvider(BaseProvider): self.logger = get_module_logger(self.__class__.__name__) if isinstance(Cal, ClientCalendarProvider): Cal.set_conn(self.client) - Inst.set_conn(self.client) + if isinstance(Inst, ClientInstrumentProvider): + Inst.set_conn(self.client) if hasattr(DatasetD, "provider"): DatasetD.provider.set_conn(self.client) else: diff --git a/setup.py b/setup.py index 83cf6e1b6..747d885f4 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ REQUIRED = [ "scipy>=1.0.0", "requests>=2.18.0", "sacred>=0.7.4", - "python-socketio==3.1.2", + "python-socketio", "redis>=3.0.1", "python-redis-lock>=3.3.1", "schedule>=0.6.0", From 5a7eecabeefdf5218a4a4ea1db5ed94343df6c42 Mon Sep 17 00:00:00 2001 From: Young Date: Tue, 27 Apr 2021 04:04:43 +0000 Subject: [PATCH 18/30] black formating (black is upgraded in github) --- examples/benchmarks/TFT/data_formatters/base.py | 2 +- qlib/contrib/backtest/position.py | 2 +- qlib/contrib/report/graph.py | 2 +- qlib/data/dataset/processor.py | 4 ++-- qlib/model/base.py | 4 ++-- qlib/portfolio/optimizer/base.py | 4 ++-- 6 files changed, 9 insertions(+), 9 deletions(-) diff --git a/examples/benchmarks/TFT/data_formatters/base.py b/examples/benchmarks/TFT/data_formatters/base.py index c68a192ba..aa1c0dc82 100644 --- a/examples/benchmarks/TFT/data_formatters/base.py +++ b/examples/benchmarks/TFT/data_formatters/base.py @@ -132,7 +132,7 @@ class GenericDataFormatter(abc.ABC): return -1, -1 def get_column_definition(self): - """"Returns formatted column definition in order expected by the TFT.""" + """Returns formatted column definition in order expected by the TFT.""" column_definition = self._column_definition diff --git a/qlib/contrib/backtest/position.py b/qlib/contrib/backtest/position.py index 6c269d505..97abc2a56 100644 --- a/qlib/contrib/backtest/position.py +++ b/qlib/contrib/backtest/position.py @@ -128,7 +128,7 @@ class Position: return self.position["cash"] def get_stock_amount_dict(self): - """generate stock amount dict {stock_id : amount of stock} """ + """generate stock amount dict {stock_id : amount of stock}""" d = {} stock_list = self.get_stock_list() for stock_code in stock_list: diff --git a/qlib/contrib/report/graph.py b/qlib/contrib/report/graph.py index 677e767ee..2d4f546e8 100644 --- a/qlib/contrib/report/graph.py +++ b/qlib/contrib/report/graph.py @@ -18,7 +18,7 @@ from ...utils import get_module_by_module_path class BaseGraph: - """""" + """ """ _name = None diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index e035f5624..7635a4127 100644 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -129,7 +129,7 @@ class FilterCol(Processor): class TanhProcess(Processor): - """ Use tanh to process noise data""" + """Use tanh to process noise data""" def __call__(self, df): def tanh_denoise(data): @@ -144,7 +144,7 @@ class TanhProcess(Processor): class ProcessInf(Processor): - """Process infinity """ + """Process infinity""" def __call__(self, df): def replace_inf(data): diff --git a/qlib/model/base.py b/qlib/model/base.py index 1ac8f2fc9..12caf5f73 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -11,11 +11,11 @@ class BaseModel(Serializable, metaclass=abc.ABCMeta): @abc.abstractmethod def predict(self, *args, **kwargs) -> object: - """ Make predictions after modeling things """ + """Make predictions after modeling things""" pass def __call__(self, *args, **kwargs) -> object: - """ leverage Python syntactic sugar to make the models' behaviors like functions """ + """leverage Python syntactic sugar to make the models' behaviors like functions""" return self.predict(*args, **kwargs) diff --git a/qlib/portfolio/optimizer/base.py b/qlib/portfolio/optimizer/base.py index 502443869..e3f692014 100644 --- a/qlib/portfolio/optimizer/base.py +++ b/qlib/portfolio/optimizer/base.py @@ -5,9 +5,9 @@ import abc class BaseOptimizer(abc.ABC): - """ Construct portfolio with a optimization related method """ + """Construct portfolio with a optimization related method""" @abc.abstractmethod def __call__(self, *args, **kwargs) -> object: - """ Generate a optimized portfolio allocation """ + """Generate a optimized portfolio allocation""" pass From eab19de080e2b2b1de93cdce7704c6535f2b2ced Mon Sep 17 00:00:00 2001 From: Jactus Date: Tue, 27 Apr 2021 16:56:07 +0800 Subject: [PATCH 19/30] Support start exp with given exp & recorder id --- qlib/workflow/__init__.py | 18 +++++++++++++++--- qlib/workflow/exp.py | 8 +++++--- qlib/workflow/expm.py | 12 ++++++++++-- 3 files changed, 30 insertions(+), 8 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index a03665626..7cb1cf5cb 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -23,7 +23,9 @@ class QlibRecorder: @contextmanager def start( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +47,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str name of the experiment one wants to start. + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -57,7 +63,7 @@ class QlibRecorder: resume : bool whether to resume the specific recorder with given name under the given experiment. """ - run = self.start_exp(experiment_name, recorder_name, uri, resume) + run = self.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) try: yield run except Exception as e: @@ -65,7 +71,9 @@ class QlibRecorder: raise e self.end_exp(Recorder.STATUS_FI) - def start_exp(self, experiment_name=None, recorder_name=None, uri=None, resume=False): + def start_exp( + self, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually and the status of the recorder may not be handled properly. Here is the example code: @@ -79,8 +87,12 @@ class QlibRecorder: Parameters ---------- + experiment_id : str + id of the experiment one wants to start. experiment_name : str the name of the experiment to be started + recorder_id : str + id of the recorder under the experiment one wants to start. recorder_name : str name of the recorder under the experiment one wants to start. uri : str @@ -93,7 +105,7 @@ class QlibRecorder: ------- An experiment instance being started. """ - return self.exp_manager.start_exp(experiment_name, recorder_name, uri, resume) + return self.exp_manager.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) def end_exp(self, recorder_status=Recorder.STATUS_FI): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 7b3d1f507..0a7e0a5a9 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -39,12 +39,14 @@ class Experiment: output["recorders"] = list(recorders.keys()) return output - def start(self, recorder_name=None, resume=False): + def start(self, recorder_id=None, recorder_name=None, resume=False): """ Start the experiment and set it to be active. This method will also start a new recorder. Parameters ---------- + recorder_id : str + the id of the recorder to be created. recorder_name : str the name of the recorder to be created. resume : bool @@ -238,14 +240,14 @@ class MLflowExperiment(Experiment): def __repr__(self): return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) - def start(self, recorder_name=None, resume=False): + def start(self, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") # Get or create recorder if recorder_name is None: recorder_name = self._default_rec_name # resume the recorder if resume: - recorder, _ = self._get_or_create_rec(recorder_name=recorder_name) + recorder, _ = self._get_or_create_rec(recorder_id=recorder_id, recorder_name=recorder_name) # create a new recorder else: recorder = self.create_recorder(recorder_name) diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 590790c9e..5549bb9bf 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -33,7 +33,9 @@ class ExpManager: def start_exp( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -45,8 +47,12 @@ class ExpManager: Parameters ---------- + experiment_id : str + id of the active experiment. experiment_name : str name of the active experiment. + recorder_id : str + id of the recorder to be started. recorder_name : str name of the recorder to be started. uri : str @@ -298,7 +304,9 @@ class MLflowExpManager(ExpManager): def start_exp( self, + experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, + recorder_id: Optional[Text] = None, recorder_name: Optional[Text] = None, uri: Optional[Text] = None, resume: bool = False, @@ -308,11 +316,11 @@ class MLflowExpManager(ExpManager): # Create experiment if experiment_name is None: experiment_name = self._default_exp_name - experiment, _ = self._get_or_create_exp(experiment_name=experiment_name) + experiment, _ = self._get_or_create_exp(experiment_id=experiment_id, experiment_name=experiment_name) # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_name, resume) + self.active_experiment.start(recorder_id, recorder_name, resume) return self.active_experiment From 8b8d21107c7f6dd6f6e6db371f4591179a4ad616 Mon Sep 17 00:00:00 2001 From: zhupr Date: Tue, 27 Apr 2021 21:20:47 +0800 Subject: [PATCH 20/30] Add future trading date collector --- qlib/data/data.py | 3 + scripts/data_collector/contrib/README.md | 24 +++++ .../contrib/future_trading_date_collector.py | 87 +++++++++++++++++++ .../data_collector/contrib/requirements.txt | 5 ++ scripts/data_collector/utils.py | 37 ++++++++ scripts/data_collector/yahoo/collector.py | 25 ++---- 6 files changed, 165 insertions(+), 16 deletions(-) create mode 100644 scripts/data_collector/contrib/README.md create mode 100644 scripts/data_collector/contrib/future_trading_date_collector.py create mode 100644 scripts/data_collector/contrib/requirements.txt diff --git a/qlib/data/data.py b/qlib/data/data.py index cea2f42eb..c2638e234 100644 --- a/qlib/data/data.py +++ b/qlib/data/data.py @@ -522,6 +522,9 @@ class LocalCalendarProvider(CalendarProvider): # if future calendar not exists, return current calendar if not os.path.exists(fname): get_module_logger("data").warning(f"{freq}_future.txt not exists, return current calendar!") + get_module_logger("data").warning( + "You can get future calendar by referring to the following document: https://github.com/microsoft/qlib/blob/main/scripts/data_collector/contrib/README.md" + ) fname = self._uri_cal.format(freq) else: fname = self._uri_cal.format(freq) diff --git a/scripts/data_collector/contrib/README.md b/scripts/data_collector/contrib/README.md new file mode 100644 index 000000000..011ff56e6 --- /dev/null +++ b/scripts/data_collector/contrib/README.md @@ -0,0 +1,24 @@ +# Get future trading days + +> `D.calendar(future=True)` will be used + +## Requirements + +```bash +pip install -r requirements.txt +``` + +## Collector Data + +```bash +# parse instruments, using in qlib/instruments. +python future_trading_date_collector.py --qlib_dir ~/.qlib/qlib_data/cn_data --freq day +``` + +## Parameters + +- qlib_dir: qlib data directory +- freq: value from [`day`, `1min`], default `day` + + + diff --git a/scripts/data_collector/contrib/future_trading_date_collector.py b/scripts/data_collector/contrib/future_trading_date_collector.py new file mode 100644 index 000000000..4da62d465 --- /dev/null +++ b/scripts/data_collector/contrib/future_trading_date_collector.py @@ -0,0 +1,87 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import sys +from typing import List +from pathlib import Path + +import fire +import numpy as np +import pandas as pd +from loguru import logger + +# get data from baostock +import baostock as bs + +CUR_DIR = Path(__file__).resolve().parent +sys.path.append(str(CUR_DIR.parent.parent)) + + +from data_collector.utils import generate_minutes_calendar_from_daily + + +def read_calendar_from_qlib(qlib_dir: Path) -> pd.DataFrame: + calendar_path = qlib_dir.joinpath("calendars").joinpath("day.txt") + if not calendar_path.exists(): + return pd.DataFrame() + return pd.read_csv(calendar_path, header=None) + + +def write_calendar_to_qlib(qlib_dir: Path, date_list: List[str], freq: str = "day"): + calendar_path = str(qlib_dir.joinpath("calendars").joinpath(f"{freq}_future.txt")) + + np.savetxt(calendar_path, date_list, fmt="%s", encoding="utf-8") + logger.info(f"write future calendars success: {calendar_path}") + + +def generate_qlib_calendar(date_list: List[str], freq: str) -> List[str]: + print(freq) + if freq == "day": + return date_list + elif freq == "1min": + date_list = generate_minutes_calendar_from_daily(date_list, freq=freq).tolist() + return list(map(lambda x: pd.Timestamp(x).strftime("%Y-%m-%d %H:%M:%S"), date_list)) + else: + raise ValueError(f"Unsupported freq: {freq}") + + +def future_calendar_collector(qlib_dir: [str, Path], freq: str = "day"): + """get future calendar + + Parameters + ---------- + qlib_dir: str or Path + qlib data directory + freq: str + value from ["day", "1min"], by default day + """ + qlib_dir = Path(qlib_dir).expanduser().resolve() + if not qlib_dir.exists(): + raise FileNotFoundError(str(qlib_dir)) + + lg = bs.login() + if lg.error_code != "0": + logger.error(f"login error: {lg.error_msg}") + return + # read daily calendar + daily_calendar = read_calendar_from_qlib(qlib_dir) + end_year = pd.Timestamp.now().year + if daily_calendar.empty: + start_year = pd.Timestamp.now().year + else: + start_year = pd.Timestamp(daily_calendar.iloc[-1, 0]).year + rs = bs.query_trade_dates(start_date=pd.Timestamp(f"{start_year}-01-01"), end_date=f"{end_year}-12-31") + data_list = [] + while (rs.error_code == "0") & rs.next(): + _row_data = rs.get_row_data() + if int(_row_data[1]) == 1: + data_list.append(_row_data[0]) + data_list = sorted(data_list) + date_list = generate_qlib_calendar(data_list, freq=freq) + write_calendar_to_qlib(qlib_dir, date_list, freq=freq) + bs.logout() + logger.info(f"get trading dates success: {start_year}-01-01 to {end_year}-12-31") + + +if __name__ == "__main__": + fire.Fire(future_calendar_collector) diff --git a/scripts/data_collector/contrib/requirements.txt b/scripts/data_collector/contrib/requirements.txt new file mode 100644 index 000000000..92dcb2374 --- /dev/null +++ b/scripts/data_collector/contrib/requirements.txt @@ -0,0 +1,5 @@ +baostock +fire +numpy +pandas +loguru diff --git a/scripts/data_collector/utils.py b/scripts/data_collector/utils.py index e8c9b9dc4..3f4539612 100644 --- a/scripts/data_collector/utils.py +++ b/scripts/data_collector/utils.py @@ -10,7 +10,9 @@ import random import requests import functools from pathlib import Path +from typing import Iterable, Tuple +import numpy as np import pandas as pd from lxml import etree from loguru import logger @@ -418,5 +420,40 @@ def get_trading_date_by_shift(trading_list: list, trading_date: pd.Timestamp, sh return res +def generate_minutes_calendar_from_daily( + calendars: Iterable, + freq: str = "1min", + am_range: Tuple[str, str] = ("09:30:00", "11:29:00"), + pm_range: Tuple[str, str] = ("13:00:00", "14:59:00"), +) -> pd.Index: + """generate minutes calendar + + Parameters + ---------- + calendars: Iterable + daily calendar + freq: str + by default 1min + am_range: Tuple[str, str] + AM Time Range, by default China-Stock: ("09:30:00", "11:29:00") + pm_range: Tuple[str, str] + PM Time Range, by default China-Stock: ("13:00:00", "14:59:00") + + """ + daily_format: str = "%Y-%m-%d" + res = [] + for _day in calendars: + for _range in [am_range, pm_range]: + res.append( + pd.date_range( + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[0]}", + f"{pd.Timestamp(_day).strftime(daily_format)} {_range[1]}", + freq=freq, + ) + ) + + return pd.Index(sorted(set(np.hstack(res)))) + + if __name__ == "__main__": assert len(get_hs_stock_symbols()) >= MINIMUM_SYMBOLS_NUM diff --git a/scripts/data_collector/yahoo/collector.py b/scripts/data_collector/yahoo/collector.py index f0e110694..a6e06613e 100644 --- a/scripts/data_collector/yahoo/collector.py +++ b/scripts/data_collector/yahoo/collector.py @@ -24,7 +24,12 @@ from qlib.config import REG_CN as REGION_CN CUR_DIR = Path(__file__).resolve().parent sys.path.append(str(CUR_DIR.parent.parent)) from data_collector.base import BaseCollector, BaseNormalize, BaseRun -from data_collector.utils import get_calendar_list, get_hs_stock_symbols, get_us_stock_symbols +from data_collector.utils import ( + get_calendar_list, + get_hs_stock_symbols, + get_us_stock_symbols, + generate_minutes_calendar_from_daily, +) INDEX_BENCH_URL = "http://push2his.eastmoney.com/api/qt/stock/kline/get?secid=1.{index_code}&fields1=f1%2Cf2%2Cf3%2Cf4%2Cf5&fields2=f51%2Cf52%2Cf53%2Cf54%2Cf55%2Cf56%2Cf57%2Cf58&klt=101&fqt=0&beg={begin}&end={end}" @@ -418,21 +423,9 @@ class YahooNormalize1min(YahooNormalize, ABC): return calendar_list_1d def generate_1min_from_daily(self, calendars: Iterable) -> pd.Index: - res = [] - daily_format = self.DAILY_FORMAT - am_range = self.AM_RANGE - pm_range = self.PM_RANGE - for _day in calendars: - for _range in [am_range, pm_range]: - res.append( - pd.date_range( - f"{_day.strftime(daily_format)} {_range[0]}", - f"{_day.strftime(daily_format)} {_range[1]}", - freq="1min", - ) - ) - - return pd.Index(sorted(set(np.hstack(res)))) + return generate_minutes_calendar_from_daily( + calendars, freq="1min", am_range=self.AM_RANGE, pm_range=self.PM_RANGE + ) def adjusted_price(self, df: pd.DataFrame) -> pd.DataFrame: # TODO: using daily data factor From 36ab078fbdbdd69f1ac93b0be75ab29253b357d3 Mon Sep 17 00:00:00 2001 From: blin Date: Wed, 28 Apr 2021 07:15:59 +0000 Subject: [PATCH 21/30] filter --- qlib/data/dataset/__init__.py | 44 ++++++++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index cd15a98c9..5485796ef 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -114,6 +114,7 @@ class DatasetH(Dataset): """ self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() + self.fetch_kwargs = {} super().__init__(**kwargs) def config(self, handler_kwargs: dict = None, **kwargs): @@ -171,7 +172,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self.handler.fetch(slc, **kwargs) + return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) def prepare( self, @@ -288,13 +289,29 @@ class TSDataSampler: # the data type will be changed # The index of usable data is between start_idx and end_idx - self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_df, self.idx_map = self.build_index(self.data) - self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - self.data_idx = deepcopy(self.data.index) + self.data_index = deepcopy(self.data.index) + if flt_data is not None: + self.flt_data = np.array(flt_data).reshape(-1) + self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) + self.data_index = self.data_index[np.where(self.flt_data == True)[0]] + + self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) + self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance + del self.data # save memory + @staticmethod + def flt_idx_map(flt_data, idx_map): + idx = 0 + new_idx_map = {} + for i, exist in enumerate(flt_data): + if exist: + new_idx_map[idx] = idx_map[i] + idx += 1 + return new_idx_map + def get_index(self): """ Get the pandas index of the data, it will be useful in following scenarios @@ -488,8 +505,19 @@ class TSDatasetH(DatasetH): """ split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data """ - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype") start, end = slc.start, slc.stop - data = self._prepare_raw_seg(slc=slc, **kwargs) - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype) - return tsds + flt_col = kwargs.pop('flt_col', None) + # TSDatasetH will retrieve more data for complete + data = self._prepare_raw_seg(slc, **kwargs) + + flt_kwargs = deepcopy(kwargs) + if flt_col is not None: + flt_kwargs['col_set'] = flt_col + flt_data = self._prepare_raw_seg(slc, **flt_kwargs) + assert len(flt_data.columns) == 1 + else: + flt_data = None + + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) + return tsds \ No newline at end of file From 40cf83e5572c141fd837c1c2e923499a6a88a31b Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Wed, 28 Apr 2021 09:23:07 +0000 Subject: [PATCH 22/30] online serving V9 middle status --- .../online_srv/online_management_simulate.py | 81 ++--- .../online_srv/rolling_online_management.py | 76 +++-- examples/online_srv/update_online_pred.py | 22 +- qlib/model/trainer.py | 10 + qlib/workflow/online/manager.py | 123 +++++++- qlib/workflow/online/simulator.py | 61 ++-- qlib/workflow/online/strategy.py | 293 ++++++++++++++++++ qlib/workflow/online/utils.py | 165 ++++++++++ qlib/workflow/task/collect.py | 25 ++ 9 files changed, 721 insertions(+), 135 deletions(-) create mode 100644 qlib/workflow/online/strategy.py create mode 100644 qlib/workflow/online/utils.py diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 1b1fed660..6a1d233ae 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -1,20 +1,23 @@ -import fire -import qlib -from qlib.model.ens.ensemble import ens_workflow -from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM -from qlib.workflow import R -from qlib.workflow.online.manager import RollingOnlineManager -from qlib.workflow.online.simulator import OnlineSimulator -from qlib.workflow.task.collect import RecorderCollector -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager -from qlib.workflow.task.utils import list_recorders - +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. """ This examples is about the OnlineManager and OnlineSimulator based on rolling tasks. The OnlineManager will focus on the updating of your online models. The OnlineSimulator will focus on the simulating real updating routine of your online models. """ +import fire +import qlib +from qlib.model.ens.ensemble import ens_workflow +from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM +from qlib.workflow import R +from qlib.workflow.online.manager import OnlineM # RollingOnlineManager +from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy +from qlib.workflow.task.collect import RecorderCollector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.manage import TaskManager +from qlib.workflow.task.utils import list_recorders + + data_handler_config = { @@ -105,6 +108,8 @@ class OnlineSimulationExample: """ self.exp_name = exp_name self.task_pool = task_pool + self.start_time = start_time + self.end_time = end_time mongo_conf = { "task_url": task_url, "task_db_name": task_db_name, @@ -115,17 +120,18 @@ class OnlineSimulationExample: ) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31. self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks - self.rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, - rolling_gen=self.rolling_gen, - trainer=self.trainer, + self.rolling_online_manager = OnlineM( + RollingAverageStrategy( + exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False + ), + begin_time=self.start_time, need_log=False, ) # The OnlineManager based on Rolling - self.onlinesimulator = OnlineSimulator( - start_time=start_time, - end_time=end_time, - online_manager=self.rolling_online_manager, - ) + # self.onlinesimulator = OnlineSimulator( + # start_time=start_time, + # end_time=end_time, + # online_manager=self.rolling_online_manager, + # ) self.tasks = tasks # Reset all things to the first status, be careful to save important data @@ -137,37 +143,16 @@ class OnlineSimulationExample: for rid in exp.list_recorders(): exp.delete_recorder(rid) - for rid in list_recorders( - RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False - ): + for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == self.exp_name else False): exp.delete_recorder(rid) - # Run this firstly to see the workflow in OnlineManager - def first_train(self): - print("========== first train ==========") - self.reset() - self.rolling_online_manager.first_train(self.tasks) - - # Run this secondly to see the simulating in OnlineSimulator - def simulate(self): - print("========== simulate ==========") - self.onlinesimulator.simulate() - print(self.rolling_online_manager.collect_artifact()) - - print("========== online models ==========") - recs_dict = self.onlinesimulator.online_models() - for time, recs in recs_dict.items(): - print(f"{str(time[0])} to {str(time[1])}:") - for rec in recs: - print(rec.info["id"]) - - print("========== online signals ==========") - print(self.rolling_online_manager.get_signals()) - # Run this to run all workflow automaticly def main(self): - self.first_train() - self.simulate() + self.reset() + print("========== simulate ==========") + self.rolling_online_manager.simulate(end_time=self.end_time) + print(self.rolling_online_manager.get_collector()()) + print(self.rolling_online_manager.get_online_history(self.exp_name)) if __name__ == "__main__": diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index d118afe75..7b2f58909 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -1,21 +1,22 @@ -import os -from pathlib import Path -import pickle -import fire -import qlib -from qlib.workflow import R -from qlib.workflow.task.gen import RollingGen -from qlib.workflow.task.manage import TaskManager -from qlib.workflow.online.manager import RollingOnlineManager -from qlib.workflow.task.utils import list_recorders -from qlib.model.trainer import TrainerRM - """ This example show how RollingOnlineManager works with rolling tasks. There are two parts including first train and routine. Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models. Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models """ +import os +from pathlib import Path +import pickle +import fire +import qlib +from qlib.workflow import R +from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy +from qlib.workflow.task.gen import RollingGen +from qlib.workflow.task.manage import TaskManager +from qlib.workflow.online.manager import OnlineM +from qlib.workflow.task.utils import list_recorders +from qlib.model.trainer import TrainerRM +from pprint import pprint data_handler_config = { "start_time": "2013-01-01", @@ -77,58 +78,65 @@ task_xgboost_config = { class RollingOnlineExample: def __init__( self, - exp_name="rolling_exp", - task_pool="rolling_task", provider_uri="~/.qlib/qlib_data/cn_data", region="cn", task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, + tasks=[task_xgboost_config, task_lgb_config], ): - self.exp_name = exp_name - self.task_pool = task_pool mongo_conf = { "task_url": task_url, # your MongoDB url "task_db_name": task_db_name, # database name } qlib.init(provider_uri=provider_uri, region=region, mongo=mongo_conf) - self.rolling_online_manager = RollingOnlineManager( - experiment_name=exp_name, - rolling_gen=RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), - trainer=TrainerRM(self.exp_name, self.task_pool), - ) + self.tasks = tasks + self.rolling_step = rolling_step + strategy = [] + for task in tasks: + name_id = task["model"]["class"] + "_" + str(self.rolling_step) + strategy.append( + RollingAverageStrategy( + name_id, + task, + RollingGen(step=rolling_step, rtype=RollingGen.ROLL_SD), + TrainerRM(experiment_name=name_id, task_pool=name_id), + ) + ) + + self.rolling_online_manager = OnlineM(strategy) _ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine. # Reset all things to the first status, be careful to save important data def reset(self): print("========== reset ==========") - TaskManager(self.task_pool).remove() - exp = R.get_exp(experiment_name=self.exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) + for task in self.tasks: + name_id = task["model"]["class"] + "_" + str(self.rolling_step) + TaskManager(name_id).remove() + exp = R.get_exp(experiment_name=name_id) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) - if os.path.exists(self._ROLLING_MANAGER_PATH): - os.remove(self._ROLLING_MANAGER_PATH) + if os.path.exists(self._ROLLING_MANAGER_PATH): + os.remove(self._ROLLING_MANAGER_PATH) - for rid in list_recorders( - RollingOnlineManager.SIGNAL_EXP, lambda x: True if x.info["name"] == self.exp_name else False - ): - exp.delete_recorder(rid) + for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == name_id else False): + exp.delete_recorder(rid) def first_run(self): print("========== first_run ==========") self.reset() - self.rolling_online_manager.first_train([task_xgboost_config, task_lgb_config]) + self.rolling_online_manager.first_train() self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) - print(self.rolling_online_manager.collect_artifact()) + print(self.rolling_online_manager.get_collector()()) def routine(self): print("========== routine ==========") with Path(self._ROLLING_MANAGER_PATH).open("rb") as f: self.rolling_online_manager = pickle.load(f) self.rolling_online_manager.routine() - print(self.rolling_online_manager.collect_artifact()) + print(self.rolling_online_manager.get_collector()()) def main(self): self.first_run() diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index ed2ad6997..a02b209bd 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -1,16 +1,14 @@ +""" +This example show how OnlineTool works when we need update prediction. +There are two parts including first_train and update_online_pred. +Firstly, we will finish the training and set the trained model to `online` model. +Next, we will finish updating online prediction. +""" import fire import qlib from qlib.config import REG_CN from qlib.model.trainer import task_train -from qlib.workflow.online.manager import OnlineManagerR -from qlib.workflow.task.utils import list_recorders - -""" -This example show how OnlineManager works when we need update prediction. -There are two parts including first_train and update_online_pred. -Firstly, the RollingOnlineManager will finish the first training and set the trained model to `online` model. -Next, the RollingOnlineManager will finish updating online prediction -""" +from qlib.workflow.online.utils import OnlineToolR data_handler_config = { "start_time": "2008-01-01", @@ -65,15 +63,15 @@ class UpdatePredExample: ): qlib.init(provider_uri=provider_uri, region=region) self.experiment_name = experiment_name - self.online_manager = OnlineManagerR(self.experiment_name) + self.online_tool = OnlineToolR(self.experiment_name) self.task_config = task_config def first_train(self): rec = task_train(self.task_config, experiment_name=self.experiment_name) - self.online_manager.reset_online_tag(rec) # set to online model + self.online_tool.reset_online_tag(rec) # set to online model def update_online_pred(self): - self.online_manager.update_online_pred() + self.online_tool.update_online_pred() def main(self): self.first_train() diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index af65c5886..0dcc1d67a 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -25,6 +25,7 @@ def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) - Returns: Recorder """ + # FIXME: recorder_id with R.start(experiment_name=experiment_name, recorder_name=str(time.time())): R.log_params(**flatten_dict(task_config)) R.save_objects(**{"task": task_config}) # keep the original format and datatype @@ -112,6 +113,9 @@ class Trainer: """ pass + def is_delay(self): + return False + class TrainerR(Trainer): """Trainer based on (R)ecorder. @@ -240,6 +244,9 @@ class DelayTrainerR(TrainerR): end_train_func(rec) return recs + def is_delay(self): + return True + class DelayTrainerRM(TrainerRM): """ @@ -286,3 +293,6 @@ class DelayTrainerRM(TrainerRM): before_status=TaskManager.STATUS_PART_DONE, ) return recs + + def is_delay(self): + return True diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index e107271d0..f8266577b 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -1,5 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This class is a component of online serving, it can manage a series of models dynamically. +With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models. +In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. +So this module provide a series methods to control this process. +""" from copy import deepcopy -from operator import index +from pprint import pprint import pandas as pd from qlib.model.ens.ensemble import ens_workflow from qlib.model.ens.group import RollingGroup @@ -9,20 +18,13 @@ from qlib import get_module_logger from qlib.data.data import D from qlib.model.trainer import Trainer, TrainerR, task_train from qlib.workflow import R +from qlib.workflow.online.strategy import OnlineStrategy from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import Collector, RecorderCollector +from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.utils import TimeAdjuster, list_recorders -""" -This class is a component of online serving, it can manage a series of models dynamically. -With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models. -In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. -So this module provide a series methods to control this process. -""" - - class OnlineManager(Serializable): ONLINE_KEY = "online_status" # the online status key in recorder @@ -357,9 +359,9 @@ class RollingOnlineManager(OnlineManagerR): Args: experiment_name (str): the experiment name. - rolling_gen (RollingGen): a instance of RollingGen - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - collector (Collector, optional): a instance of Collector. Defaults to None. + rolling_gen (RollingGen): an instance of RollingGen + trainer (Trainer, optional): an instance of Trainer. Defaults to None. + collector (Collector, optional): an instance of Collector. Defaults to None. need_log (bool, optional): print log or not. Defaults to True. """ if trainer is None: @@ -475,3 +477,98 @@ class RollingOnlineManager(OnlineManagerR): if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: latest_rec[rid] = rec return latest_rec, max_test + + +class OnlineM(Serializable): + def __init__( + self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], begin_time=None, freq="day", need_log=True + ): + self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + if not isinstance(strategy, list): + strategy = [strategy] + self.strategy = strategy + self.freq = freq + if begin_time is None: + begin_time = D.calendar(freq=self.freq).max() + self.cur_time = pd.Timestamp(begin_time) + self.history = {} + + def first_train(self): + """ + Train a series of models firstly and set some of them into online models. + """ + for strategy in self.strategy: + self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") + online_models = strategy.first_train() + self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models + + def routine(self, cur_time=None, task_kwargs={}, model_kwargs={}): + """ + The typical update process after a routine, such as day by day or month by month. + update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models + + NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions. + + Args: + cur_time ([type], optional): [description]. Defaults to None. + delay_prepare (bool, optional): [description]. Defaults to False. + *args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config. + + Returns: + [type]: [description] + """ + if cur_time is None: + cur_time = D.calendar(freq=self.freq).max() + self.cur_time = pd.Timestamp(cur_time) # None for latest date + for strategy in self.strategy: + self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") + if not strategy.trainer.is_delay(): + strategy.prepare_signals() + tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) + online_models = strategy.prepare_online_models(tasks, **model_kwargs) + if len(online_models) > 0: + self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models + + def get_collector(self): + collector_dict = {} + for strategy in self.strategy: + collector_dict[strategy.name_id] = strategy.get_collector() + return HyperCollector(collector_dict) + + def get_online_history(self, strategy_name_id): + history_dict = self.history[strategy_name_id] + history = [] + for time in sorted(history_dict): + models = history_dict[time] + history.append((time, models)) + return history + + def delay_prepare(self, delay_kwargs={}): + """ + Prepare all models and signals if there are something waiting for prepare. + NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. + + Args: + rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. + *args, **kwargs: will be passed to end_train which means will be passed to customized train method. + """ + for strategy in self.strategy: + strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs) + + def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}): + """ + Starting from start time, this method will simulate every routine in OnlineManager. + NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. + + Returns: + Collector: the OnlineManager's collector + """ + cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency) + self.first_train() + for cur_time in cal: + self.logger.info(f"Simulating at {str(cur_time)}......") + self.routine(cur_time, task_kwargs=task_kwargs, model_kwargs=model_kwargs) + self.delay_prepare(delay_kwargs=delay_kwargs) + self.logger.info(f"Finished preparing signals") + return self.get_collector() diff --git a/qlib/workflow/online/simulator.py b/qlib/workflow/online/simulator.py index d45b7d99d..ddaf2471c 100644 --- a/qlib/workflow/online/simulator.py +++ b/qlib/workflow/online/simulator.py @@ -1,6 +1,6 @@ from qlib.data import D from qlib import get_module_logger -from qlib.workflow.online.manager import OnlineManager +from qlib.workflow.online.manager import OnlineM class OnlineSimulator: @@ -32,7 +32,35 @@ class OnlineSimulator: if len(self.cal) == 0: self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.") - def simulate(self, *args, **kwargs): + # def simulate(self, *args, **kwargs): + # """ + # Starting from start time, this method will simulate every routine in OnlineManager. + # NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. + + # Returns: + # Collector: the OnlineManager's collector + # """ + # self.rec_dict = {} + # tmp_begin = self.start_time + # tmp_end = None + # self.olm.first_train() + # prev_recorders = self.olm.online_models() + # for cur_time in self.cal: + # self.logger.info(f"Simulating at {str(cur_time)}......") + # recorders = self.olm.routine(cur_time, True, *args, **kwargs) + # if len(recorders) == 0: + # tmp_end = cur_time + # else: + # self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders + # tmp_begin = cur_time + # prev_recorders = recorders + # self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders + # # finished perparing models (and pred) and signals + # self.olm.delay_prepare(self.rec_dict) + # self.logger.info(f"Finished preparing signals") + # return self.olm.get_collector() + + def simulate(self, task_kwargs={}, model_kwargs={}): """ Starting from start time, this method will simulate every routine in OnlineManager. NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. @@ -40,33 +68,10 @@ class OnlineSimulator: Returns: Collector: the OnlineManager's collector """ - self.rec_dict = {} - tmp_begin = self.start_time - tmp_end = None - prev_recorders = self.olm.online_models() + self.olm.first_train() for cur_time in self.cal: self.logger.info(f"Simulating at {str(cur_time)}......") - recorders = self.olm.routine(cur_time, True, *args, **kwargs) - if len(recorders) == 0: - tmp_end = cur_time - else: - self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders - tmp_begin = cur_time - prev_recorders = recorders - self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders - # finished perparing models (and pred) and signals - self.olm.delay_prepare(self.rec_dict) + self.olm.routine(cur_time, task_kwargs={}, model_kwargs={}) + self.olm.delay_prepare() self.logger.info(f"Finished preparing signals") return self.olm.get_collector() - - def online_models(self): - """ - Return a online models dict likes {(begin_time, end_time):[online models]}. - - Returns: - dict - """ - if hasattr(self, "rec_dict"): - return self.rec_dict - self.logger.warn(f"Please call `simulate` firstly when calling `online_models`") - return {} diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py new file mode 100644 index 000000000..5e4dcc024 --- /dev/null +++ b/qlib/workflow/online/strategy.py @@ -0,0 +1,293 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +""" +This module is working with OnlineManager, responsing for a set of strategy about how the models are updated and signals are perpared. +""" + +from copy import deepcopy +from typing import List, Union +import pandas as pd +from qlib.data.data import D +from qlib.log import get_module_logger +from qlib.model.ens.group import RollingGroup +from qlib.model.trainer import Trainer, TrainerR +from qlib.workflow import R +from qlib.workflow.online.utils import OnlineTool, OnlineToolR +from qlib.workflow.task.collect import HyperCollector, RecorderCollector +from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.workflow.task.utils import TimeAdjuster, list_recorders + + +class OnlineStrategy: + def __init__(self, name_id: str, trainer: Trainer = None, need_log=True): + """ + init OnlineManager. + + Args: + name_id (str): a unique name or id + trainer (Trainer, optional): a instance of Trainer. Defaults to None. + need_log (bool, optional): print log or not. Defaults to True. + """ + self.name_id = name_id + self.trainer = trainer + self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + self.tool = OnlineTool() + self.history = {} + + def prepare_signals(self, delay=False): + """ + After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. + Must use `pass` even though there is nothing to do. + """ + raise NotImplementedError(f"Please implement the `prepare_signals` method.") + + def prepare_tasks(self, *args, **kwargs): + """ + After the end of a routine, check whether we need to prepare and train some new tasks. + return the new tasks waiting for training. + """ + raise NotImplementedError(f"Please implement the `prepare_tasks` method.") + + def prepare_online_models(self, tasks, check_func=None, **kwargs): + """ + Use trainer to train a list of tasks and set the trained model to `online`. + + Args: + tasks (list): a list of tasks. + tag (str): + `ONLINE_TAG` for first train or additional train + `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` + `OFFLINE_TAG` for train but offline those models + check_func: the method to judge if a model can be online. + The parameter is the model record and return True for online. + None for online every models. + **kwargs: will be passed to end_train which means will be passed to customized train method. + + """ + if check_func is None: + check_func = lambda x: True + online_models = [] + if len(tasks) > 0: + new_models = self.trainer.train(tasks, **kwargs) + for model in new_models: + if check_func(model): + online_models.append(model) + self.tool.reset_online_tag(online_models) + return online_models + + def first_train(self): + """ + Train a series of models firstly and set some of them into online models. + """ + raise NotImplementedError(f"Please implement the `first_train` method.") + + def get_collector(self): + """ + Return the collector. + + Returns: + Collector + """ + raise NotImplementedError(f"Please implement the `get_collector` method.") + + def delay_prepare(self, history, **kwargs): + """ + Prepare all models and signals if there are something waiting for prepare. + NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. + + Args: + rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. + *args, **kwargs: will be passed to end_train which means will be passed to customized train method. + """ + for time_begin, recs_list in history: + self.trainer.end_train(recs_list, **kwargs) + self.tool.reset_online_tag(recs_list) + self.prepare_signals(delay=True) + + +class RollingAverageStrategy(OnlineStrategy): + + """ + This example strategy always use latest rolling model as online model and prepare trading signals using the average prediction of online models + """ + + def __init__( + self, + name_id: str, + task_template: Union[dict, List[dict]], + rolling_gen: RollingGen, + trainer: Trainer = None, + need_log=True, + signal_exp_name="OnlineManagerSignals", + ): + """ + init OnlineManagerR. + + Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one. + + Args: + name_id (str): a unique name or id. Will be also the name of Experiment. + task_template (Union[dict,List[dict]]): a list of task_template or a single template, which will be used to generate many tasks using rolling_gen. + rolling_gen (RollingGen): an instance of RollingGen + trainer (Trainer, optional): a instance of Trainer. Defaults to None. + need_log (bool, optional): print log or not. Defaults to True. + signal_exp_path (str): a specific experiment to save signals of different experiment. + """ + super().__init__(name_id=name_id, trainer=trainer, need_log=need_log) + self.exp_name = self.name_id + if not isinstance(task_template, list): + task_template = [task_template] + self.task_template = task_template + self.signal_rec = None + self.signal_exp_name = signal_exp_name + self.ta = TimeAdjuster() + self.rg = rolling_gen + self.tool = OnlineToolR(self.exp_name) + + def get_collector(self, rec_key_func=None, rec_filter_func=None): + """ + Get the instance of collector to collect results. The returned collector must can distinguish results in different models. + Assumption: the models can be distinguished based on model name and rolling test segments. + If you do not want this assumption, please implement your own method or use another rec_key_func. + + Args: + rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. + rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + """ + + def rec_key(recorder): + task_config = recorder.load_object("task") + model_key = task_config["model"]["class"] + rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] + return model_key, rolling_key + + if rec_key_func is None: + rec_key_func = rec_key + + artifacts_collector = RecorderCollector( + experiment=self.exp_name, + process_list=RollingGroup(), + rec_key_func=rec_key_func, + rec_filter_func=rec_filter_func, + ) + + signals_collector = RecorderCollector( + experiment=self.signal_exp_name, + rec_key_func=lambda rec: rec.info["name"], + rec_filter_func=lambda rec: rec.info["name"] == self.exp_name, + artifacts_path={"signals": "signals"}, + ) + return HyperCollector({"artifacts": artifacts_collector, "signals": signals_collector}) + + def first_train(self): + """ + Use rolling_gen to generate different tasks based on task_template and trained them. + + Returns: + Collector: a instance of a Collector. + """ + tasks = task_generator( + tasks=self.task_template, + generators=self.rg, # generate different date segment + ) + return self.prepare_online_models(tasks) + + def prepare_tasks(self, cur_time): + """ + Prepare new tasks based on cur_time (None for latest). + + Returns: + list: a list of new tasks. + """ + latest_records, max_test = self._list_latest(self.tool.online_models()) + if max_test is None: + self.logger.warn(f"No latest online recorders, no new tasks.") + return [] + calendar_latest = D.calendar(end_time=cur_time)[-1] if cur_time is None else cur_time + if self.need_log: + self.logger.info( + f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" + ) + if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: + old_tasks = [] + tasks_tmp = [] + for rec in latest_records: + task = rec.load_object("task") + old_tasks.append(deepcopy(task)) + test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] + # modify the test segment to generate new tasks + task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) + tasks_tmp.append(task) + new_tasks_tmp = task_generator(tasks_tmp, self.rg) + new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] + return new_tasks + return [] + + def prepare_signals(self, delay=False, over_write=False): + """ + Average the predictions of online models and offer a trading signals every routine. + The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` + Even if the latest signal already exists, the latest calculation result will be overwritten. + NOTE: Given a prediction of a certain time, all signals before this time will be prepared well. + Args: + over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. + Returns: + object: the signals. + """ + if not delay: + self.tool.update_online_pred() + if self.signal_rec is None: + with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): + self.signal_rec = R.get_recorder() + + pred = [] + try: + old_signals = self.signal_rec.load_object("signals") + except OSError: + old_signals = None + + for rec in self.tool.online_models(): + pred.append(rec.load_object("pred.pkl")) + + signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") + signals = signals.sort_index() + if old_signals is not None and not over_write: + old_max = old_signals.index.get_level_values("datetime").max() + new_signals = signals.loc[old_max:] + signals = pd.concat([old_signals, new_signals], axis=0) + else: + new_signals = signals + if self.need_log: + self.logger.info( + f"Finished preparing new {len(new_signals)} signals to {self.signal_exp_name}/{self.exp_name}." + ) + self.signal_rec.save_objects(**{"signals": signals}) + return signals + + # def get_signals(self): + # """ + # get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) + + # Returns: + # signals + # """ + # if self.signal_rec is None: + # with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): + # self.signal_rec = R.get_recorder() + # signals = None + # try: + # signals = self.signal_rec.load_object("signals") + # except OSError: + # self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?") + # return signals + + def _list_latest(self, rec_list): + if len(rec_list) == 0: + return rec_list, None + max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list) + latest_rec = [] + for rec in rec_list: + if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: + latest_rec.append(rec) + return latest_rec, max_test diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py new file mode 100644 index 000000000..1cd89d668 --- /dev/null +++ b/qlib/workflow/online/utils.py @@ -0,0 +1,165 @@ +""" +This module is like a online backend, deciding which models are `online` models and how can change them +""" +from typing import List, Union +from qlib.log import get_module_logger +from qlib.workflow.online.update import PredUpdater +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.utils import list_recorders + + +class OnlineTool: + + ONLINE_KEY = "online_status" # the online status key in recorder + ONLINE_TAG = "online" # the 'online' model + # NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models. + NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model + OFFLINE_TAG = "offline" # the 'offline' model, not for online serving + + def __init__(self, need_log=True): + """ + init OnlineTool. + + Args: + need_log (bool, optional): print log or not. Defaults to True. + """ + self.logger = get_module_logger(self.__class__.__name__) + self.need_log = need_log + self.cur_time = None + + def set_online_tag(self, tag, recorder): + """ + Set `tag` to the model to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + """ + raise NotImplementedError(f"Please implement the `set_online_tag` method.") + + def get_online_tag(self): + """ + Given a model and return its online tag. + """ + raise NotImplementedError(f"Please implement the `get_online_tag` method.") + + def reset_online_tag(self, recorders=None): + """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. + + Args: + recorders (List, optional): + the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. + + Returns: + list: new online recorder. [] if there is no update. + """ + raise NotImplementedError(f"Please implement the `reset_online_tag` method.") + + def online_models(self): + """ + Return `online` models. + """ + raise NotImplementedError(f"Please implement the `online_models` method.") + + def update_online_pred(self, to_date=None): + """ + Update the predictions of online models to a date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for latest. + + """ + raise NotImplementedError(f"Please implement the `update_online_pred` method.") + + +class OnlineToolR(OnlineTool): + """ + The implementation of OnlineTool based on (R)ecorder. + + """ + + def __init__(self, experiment_name: str, need_log=True): + """ + init OnlineToolR. + + Args: + experiment_name (str): the experiment name. + need_log (bool, optional): print log or not. Defaults to True. + """ + super().__init__(need_log=need_log) + self.exp_name = experiment_name + + def set_online_tag(self, tag, recorder: Union[Recorder, List]): + """ + Set `tag` to the model to sign whether online. + + Args: + tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[Recorder, List]) + """ + if isinstance(recorder, Recorder): + recorder = [recorder] + for rec in recorder: + rec.set_tags(**{self.ONLINE_KEY: tag}) + if self.need_log: + self.logger.info(f"Set {len(recorder)} models to '{tag}'.") + + def get_online_tag(self, recorder: Recorder): + """ + Given a model and return its online tag. + + Args: + recorder (Recorder): a instance of recorder + + Returns: + str: the tag + """ + tags = recorder.list_tags() + return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG) + + def reset_online_tag(self, recorder: Union[Recorder, List] = None): + """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. + + Args: + recorders (Union[Recorder, List], optional): + the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. + + Returns: + list: new online recorder. [] if there is no update. + """ + if recorder is None: + recorder = list( + list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.NEXT_ONLINE_TAG).values() + ) + if isinstance(recorder, Recorder): + recorder = [recorder] + if len(recorder) == 0: + if self.need_log: + self.logger.info("No 'next online' model, just use current 'online' models.") + return [] + recs = list_recorders(self.exp_name) + self.set_online_tag(self.OFFLINE_TAG, list(recs.values())) + self.set_online_tag(self.ONLINE_TAG, recorder) + return recorder + + def online_models(self): + """ + Return online models. + + Returns: + list: the list of online models + """ + return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) + + def update_online_pred(self, to_date=None): + """ + Update the predictions of online models to a date. + + Args: + to_date (pd.Timestamp): the pred before this date will be updated. None for latest in Calendar. + """ + online_models = self.online_models() + for rec in online_models: + PredUpdater(rec, to_date=to_date, need_log=self.need_log).update() + + if self.need_log: + self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index b4c81122d..eb0a20029 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,5 +1,6 @@ from abc import abstractmethod from typing import Callable, Union +from qlib import init from qlib.workflow import R from qlib.workflow.task.utils import list_recorders from qlib.utils.serial import Serializable @@ -109,6 +110,27 @@ class Collector: raise TypeError(f"The instance of {type(collector)} is not a valid `Collector`!") +class HyperCollector(Collector): + """ + A collector to collect the results of other Collectors + """ + + def __init__(self, collector_dict, process_list=[]): + """ + Args: + collector_dict (dict): the dict like {collector_key, Collector} + process_list (list or Callable): the list of processors or the instance of processor to process dict. + """ + super().__init__(process_list=process_list) + self.collector_dict = collector_dict + + def collect(self): + collect_dict = {} + for key, collector in self.collector_dict.items(): + collect_dict[key] = collector() + return collect_dict + + class RecorderCollector(Collector): ART_KEY_RAW = "__raw" @@ -180,3 +202,6 @@ class RecorderCollector(Collector): collect_dict.setdefault(key, {})[rec_key] = artifact return collect_dict + + def get_exp_name(self): + return self.experiment.name From 67c5740c83b428519427854efb214e58c28eb9ab Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Thu, 29 Apr 2021 04:30:09 +0000 Subject: [PATCH 23/30] OnlineServing V9 --- .../model_rolling/task_manager_rolling.py | 29 +- .../online_srv/online_management_simulate.py | 54 +- .../online_srv/rolling_online_management.py | 43 +- examples/online_srv/update_online_pred.py | 3 + qlib/data/dataset/__init__.py | 8 +- qlib/model/ens/ensemble.py | 45 +- qlib/model/ens/group.py | 15 +- qlib/model/task.py | 27 - qlib/model/trainer.py | 288 ++++++---- qlib/utils/serial.py | 5 +- qlib/workflow/online/manager.py | 540 ++---------------- qlib/workflow/online/simulator.py | 77 --- qlib/workflow/online/strategy.py | 103 +++- qlib/workflow/online/update.py | 84 +-- qlib/workflow/online/utils.py | 94 +-- qlib/workflow/task/collect.py | 32 +- qlib/workflow/task/gen.py | 24 +- qlib/workflow/task/manage.py | 155 +++-- qlib/workflow/task/utils.py | 61 +- 19 files changed, 677 insertions(+), 1010 deletions(-) delete mode 100644 qlib/model/task.py delete mode 100644 qlib/workflow/online/simulator.py diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index ab3a4eee5..175319885 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -1,24 +1,23 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +This example shows how a TrainerRM work based on TaskManager with rolling tasks. +After training, how to collect the rolling results will be showed in task_collecting. +""" + from pprint import pprint -import time import fire import qlib from qlib.config import REG_CN -from qlib.model.trainer import TrainerR, task_train from qlib.workflow import R from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.manage import TaskManager, run_task +from qlib.workflow.task.manage import TaskManager from qlib.workflow.task.collect import RecorderCollector -from qlib.model.ens.ensemble import RollingEnsemble, ens_workflow -import pandas as pd -from qlib.workflow.task.utils import list_recorders from qlib.model.ens.group import RollingGroup from qlib.model.trainer import TrainerRM -""" -This example shows how a Trainer work based on TaskManager with rolling tasks. -After training, how to collect the rolling results will be showed in task_collecting. -""" data_handler_config = { "start_time": "2008-01-01", @@ -139,11 +138,13 @@ class RollingTaskExample: return True return False - artifact = ens_workflow( - RecorderCollector(experiment=self.experiment_name, rec_key_func=rec_key, rec_filter_func=my_filter), - RollingGroup(), + collector = RecorderCollector( + experiment=self.experiment_name, + process_list=RollingGroup(), + rec_key_func=rec_key, + rec_filter_func=my_filter, ) - print(artifact) + print(collector()) def main(self): self.reset() diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 6a1d233ae..16e985ccd 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -1,23 +1,17 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + """ -This examples is about the OnlineManager and OnlineSimulator based on rolling tasks. -The OnlineManager will focus on the updating of your online models. -The OnlineSimulator will focus on the simulating real updating routine of your online models. +This examples is about how can simulate the OnlineManager based on rolling tasks. """ + import fire import qlib -from qlib.model.ens.ensemble import ens_workflow -from qlib.model.trainer import DelayTrainerR, DelayTrainerRM, TrainerRM -from qlib.workflow import R -from qlib.workflow.online.manager import OnlineM # RollingOnlineManager -from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy -from qlib.workflow.task.collect import RecorderCollector -from qlib.workflow.task.gen import RollingGen, task_generator +from qlib.model.trainer import DelayTrainerRM +from qlib.workflow.online.manager import OnlineManager +from qlib.workflow.online.strategy import RollingAverageStrategy +from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.workflow.task.utils import list_recorders - - data_handler_config = { @@ -89,10 +83,10 @@ class OnlineSimulationExample: rolling_step=80, start_time="2018-09-10", end_time="2018-10-31", - tasks=[task_xgboost_config], # , task_lgb_config] + tasks=[task_xgboost_config, task_lgb_config], ): """ - init OnlineManagerExample. + Init OnlineManagerExample. Args: provider_uri (str, optional): the provider uri. Defaults to "~/.qlib/qlib_data/cn_data". @@ -120,42 +114,28 @@ class OnlineSimulationExample: ) # The rolling tasks generator, modify_end_time is false because we just need simulate to 2018-10-31. self.trainer = DelayTrainerRM(self.exp_name, self.task_pool) self.task_manager = TaskManager(self.task_pool) # A good way to manage all your tasks - self.rolling_online_manager = OnlineM( + self.rolling_online_manager = OnlineManager( RollingAverageStrategy( exp_name, task_template=tasks, rolling_gen=self.rolling_gen, trainer=self.trainer, need_log=False ), begin_time=self.start_time, need_log=False, - ) # The OnlineManager based on Rolling - # self.onlinesimulator = OnlineSimulator( - # start_time=start_time, - # end_time=end_time, - # online_manager=self.rolling_online_manager, - # ) + ) self.tasks = tasks - # Reset all things to the first status, be careful to save important data - def reset(self): - print("========== reset ==========") - self.task_manager.remove() - - exp = R.get_exp(experiment_name=self.exp_name) - for rid in exp.list_recorders(): - exp.delete_recorder(rid) - - for rid in list_recorders("OnlineManagerSignals", lambda x: True if x.info["name"] == self.exp_name else False): - exp.delete_recorder(rid) - - # Run this to run all workflow automaticly + # Run this to run all workflow automatically def main(self): - self.reset() + print("========== reset ==========") + self.rolling_online_manager.reset() print("========== simulate ==========") self.rolling_online_manager.simulate(end_time=self.end_time) + print("========== collect results ==========") print(self.rolling_online_manager.get_collector()()) + print("========== online history ==========") print(self.rolling_online_manager.get_online_history(self.exp_name)) if __name__ == "__main__": - ## to run all workflow automaticly with your own parameters, use the command below + ## to run all workflow automatically with your own parameters, use the command below # python online_management_simulate.py main --experiment_name="your_exp_name" --rolling_step=60 fire.Fire(OnlineSimulationExample) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 7b2f58909..950c9684d 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -1,22 +1,25 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ -This example show how RollingOnlineManager works with rolling tasks. +This example show how OnlineManager works with rolling tasks. There are two parts including first train and routine. -Firstly, the RollingOnlineManager will finish the first training and set trained models to `online` models. -Next, the RollingOnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models +Firstly, the OnlineManager will finish the first training and set trained models to `online` models. +Next, the OnlineManager will finish a routine process, including update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models """ + import os from pathlib import Path import pickle import fire import qlib from qlib.workflow import R -from qlib.workflow.online.strategy import OnlineStrategy, RollingAverageStrategy +from qlib.workflow.online.strategy import RollingAverageStrategy from qlib.workflow.task.gen import RollingGen from qlib.workflow.task.manage import TaskManager -from qlib.workflow.online.manager import OnlineM +from qlib.workflow.online.manager import OnlineManager from qlib.workflow.task.utils import list_recorders from qlib.model.trainer import TrainerRM -from pprint import pprint data_handler_config = { "start_time": "2013-01-01", @@ -94,7 +97,7 @@ class RollingOnlineExample: self.rolling_step = rolling_step strategy = [] for task in tasks: - name_id = task["model"]["class"] + "_" + str(self.rolling_step) + name_id = task["model"]["class"] # NOTE: Assumption: The model class can specify only one strategy strategy.append( RollingAverageStrategy( name_id, @@ -104,9 +107,12 @@ class RollingOnlineExample: ) ) - self.rolling_online_manager = OnlineM(strategy) + self.rolling_online_manager = OnlineManager(strategy) + self.collector = self.rolling_online_manager.get_collector() - _ROLLING_MANAGER_PATH = ".rolling_manager" # the RollingOnlineManager will dump to this file, for it will be loaded when calling routine. + _ROLLING_MANAGER_PATH = ( + ".RollingOnlineExample" # the OnlineManager will dump to this file, for it can be loaded when calling routine. + ) # Reset all things to the first status, be careful to save important data def reset(self): @@ -125,18 +131,23 @@ class RollingOnlineExample: exp.delete_recorder(rid) def first_run(self): + print("========== reset ==========") + self.rolling_online_manager.reset() print("========== first_run ==========") - self.reset() self.rolling_online_manager.first_train() + print("========== dump ==========") self.rolling_online_manager.to_pickle(self._ROLLING_MANAGER_PATH) - print(self.rolling_online_manager.get_collector()()) + print("========== collect results ==========") + print(self.collector()) def routine(self): - print("========== routine ==========") + print("========== load ==========") with Path(self._ROLLING_MANAGER_PATH).open("rb") as f: self.rolling_online_manager = pickle.load(f) + print("========== routine ==========") self.rolling_online_manager.routine() - print(self.rolling_online_manager.get_collector()()) + print("========== collect results ==========") + print(self.collector()) def main(self): self.first_run() @@ -145,11 +156,11 @@ class RollingOnlineExample: if __name__ == "__main__": ####### to train the first version's models, use the command below - # python task_manager_rolling_with_updating.py first_run + # python rolling_online_management.py first_run ####### to update the models and predictions after the trading time, use the command below - # python task_manager_rolling_with_updating.py after_day + # python rolling_online_management.py after_day ####### to define your own parameters, use `--` - # python task_manager_rolling_with_updating.py first_run --exp_name='your_exp_name' --rolling_step=40 + # python rolling_online_management.py first_run --exp_name='your_exp_name' --rolling_step=40 fire.Fire(RollingOnlineExample) diff --git a/examples/online_srv/update_online_pred.py b/examples/online_srv/update_online_pred.py index a02b209bd..6e2725c7a 100644 --- a/examples/online_srv/update_online_pred.py +++ b/examples/online_srv/update_online_pred.py @@ -1,3 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ This example show how OnlineTool works when we need update prediction. There are two parts including first_train and update_online_pred. diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 5485796ef..4457dda5f 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -299,7 +299,7 @@ class TSDataSampler: self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - + del self.data # save memory @staticmethod @@ -507,17 +507,17 @@ class TSDatasetH(DatasetH): """ dtype = kwargs.pop("dtype") start, end = slc.start, slc.stop - flt_col = kwargs.pop('flt_col', None) + flt_col = kwargs.pop("flt_col", None) # TSDatasetH will retrieve more data for complete data = self._prepare_raw_seg(slc, **kwargs) flt_kwargs = deepcopy(kwargs) if flt_col is not None: - flt_kwargs['col_set'] = flt_col + flt_kwargs["col_set"] = flt_col flt_data = self._prepare_raw_seg(slc, **flt_kwargs) assert len(flt_data.columns) == 1 else: flt_data = None tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) - return tsds \ No newline at end of file + return tsds diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 63f6438c2..7ccf98ab2 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -1,36 +1,11 @@ -from abc import abstractmethod -from typing import Callable, Union +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions. +""" import pandas as pd -from qlib.workflow.task.collect import Collector -from qlib.utils.serial import Serializable - - -def ens_workflow(collector: Collector, process_list, *args, **kwargs): - """the ensemble workflow based on collector and different dict processors. - - Args: - collector (Collector): the collector to collect the result into {result_key: things} - process_list (list or Callable): the list of processors or the instance of processor to process dict. - The processor order is same as the list order. - For example: [Group1(..., Ensemble1()), Group2(..., Ensemble2())] - Returns: - dict: the ensemble dict - """ - collect_dict = collector.collect() - if not isinstance(process_list, list): - process_list = [process_list] - - ensemble = {} - for artifact in collect_dict: - value = collect_dict[artifact] - for process in process_list: - if not callable(process): - raise NotImplementedError(f"{type(process)} is not supported in `ens_workflow`.") - value = process(value, *args, **kwargs) - ensemble[artifact] = value - - return ensemble class Ensemble: @@ -53,17 +28,17 @@ class RollingEnsemble(Ensemble): """Merge the rolling objects in an Ensemble""" - def __call__(self, ensemble_dict: dict): + def __call__(self, ensemble_dict: dict) -> pd.DataFrame: """Merge a dict of rolling dataframe like `prediction` or `IC` into an ensemble. - NOTE: The values of dict must be pd.Dataframe, and have the index "datetime" + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime" Args: - ensemble_dict (dict): a dict like {"A": pd.Dataframe, "B": pd.Dataframe}. + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. The key of the dict will be ignored. Returns: - pd.Dataframe: the complete result of rolling. + pd.DataFrame: the complete result of rolling. """ artifact_list = list(ensemble_dict.values()) artifact_list.sort(key=lambda x: x.index.get_level_values("datetime").min()) diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index c80959b0d..d53a55f4c 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -1,3 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Group can group a set of object based on `group_func` and change them to a dict. +""" + from qlib.model.ens.ensemble import Ensemble, RollingEnsemble from typing import Callable, Union from joblib import Parallel, delayed @@ -21,20 +28,20 @@ class Group: self._group_func = group_func self._ens_func = ens - def group(self, *args, **kwargs): + def group(self, *args, **kwargs) -> dict: # TODO: such design is weird when `_group_func` is the only configurable part in the class if isinstance(getattr(self, "_group_func", None), Callable): return self._group_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `group_func`.") - def reduce(self, *args, **kwargs): + def reduce(self, *args, **kwargs) -> dict: if isinstance(getattr(self, "_ens_func", None), Callable): return self._ens_func(*args, **kwargs) else: raise NotImplementedError(f"Please specify valid `_ens_func`.") - def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs): + def __call__(self, ungrouped_dict: dict, n_jobs=1, verbose=0, *args, **kwargs) -> dict: """Group the ungrouped_dict into different groups. Args: @@ -59,7 +66,7 @@ class Group: class RollingGroup(Group): """group the rolling dict""" - def group(self, rolling_dict: dict): + def group(self, rolling_dict: dict) -> dict: """Given an rolling dict likes {(A,B,R): things}, return the grouped dict likes {(A,B): {R:things}} NOTE: There is a assumption which is the rolling key is at the end of key tuple, because the rolling results always need to be ensemble firstly. diff --git a/qlib/model/task.py b/qlib/model/task.py deleted file mode 100644 index f29f513a4..000000000 --- a/qlib/model/task.py +++ /dev/null @@ -1,27 +0,0 @@ -import abc -import typing - - -class TaskGen(metaclass=abc.ABCMeta): - @abc.abstractmethod - def __call__(self, *args, **kwargs) -> typing.List[dict]: - """ - generate - - Parameters - ---------- - args, kwargs: - The info for generating tasks - Example 1): - input: a specific task template - output: rolling version of the tasks - Example 2): - input: a specific task template - output: a set of tasks with different losses - - Returns - ------- - typing.List[dict]: - A list of tasks - """ - pass diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index 0dcc1d67a..a0d252ab4 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -1,59 +1,72 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import copy +""" +The Trainer will train a list of tasks and return a list of model recorder. +There are two steps in each Trainer including `train`(make model recorder) and `end_train`(modify model recorder). + +This is concept called "DelayTrainer", which can be used in online simulating to parallel training. +In "DelayTrainer", the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. + +`Qlib` offer two kind of Trainer, TrainerR is simplest and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically. +""" + +import socket import time -from xxlimited import Str -from qlib.utils import init_instance_by_config, flatten_dict, get_cls_kwargs -from qlib.workflow import R -from qlib.workflow.recorder import Recorder -from qlib.workflow.record_temp import SignalRecord -from qlib.workflow.task.manage import TaskManager, run_task +from typing import Callable, List + from qlib.data.dataset import Dataset from qlib.model.base import Model -import socket +from qlib.utils import flatten_dict, get_cls_kwargs, init_instance_by_config +from qlib.workflow import R +from qlib.workflow.record_temp import SignalRecord +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.manage import TaskManager, run_task -def begin_task_train(task_config: dict, experiment_name: str, *args, **kwargs) -> Recorder: +def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str = None) -> Recorder: """ - Begin a task training with starting a recorder and saving the task config. + Begin a task training to start a recorder and save the task config. Args: - task_config (dict) - experiment_name (str) + task_config (dict): the config of a task + experiment_name (str): the name of experiment + recorder_name (str): the given name will be the recorder name. None for using rid. Returns: - Recorder + Recorder: the model recorder """ # FIXME: recorder_id - with R.start(experiment_name=experiment_name, recorder_name=str(time.time())): + if recorder_name is None: + recorder_name = str(time.time()) + with R.start(experiment_name=experiment_name, recorder_name=recorder_name): R.log_params(**flatten_dict(task_config)) R.save_objects(**{"task": task_config}) # keep the original format and datatype - R.set_tags(**{"hostname": socket.gethostname(), "train_status": "begin_task_train"}) + R.set_tags(**{"hostname": socket.gethostname()}) recorder: Recorder = R.get_recorder() return recorder -def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs): +def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: """ - Finished task training with real model fitting and saving. + Finish task training with real model fitting and saving. Args: - rec (Recorder): This recorder will be resumed - experiment_name (str) + rec (Recorder): the recorder will be resumed + experiment_name (str): the name of experiment Returns: - Recorder + Recorder: the model recorder """ with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True): task_config = R.load_object("task") - # model & dataset initiaiton + # model & dataset initiation model: Model = init_instance_by_config(task_config["model"]) dataset: Dataset = init_instance_by_config(task_config["dataset"]) # model training model.fit(dataset) R.save_objects(**{"params.pkl": model}) - # This dataset is saved for online inference. So the concrete data should not be dumped + # this dataset is saved for online inference. So the concrete data should not be dumped dataset.config(dump_all=False, recursive=True) R.save_objects(**{"dataset": dataset}) # generate records: prediction, backtest, and analysis @@ -68,18 +81,18 @@ def end_task_train(rec: Recorder, experiment_name: str, *args, **kwargs): rconf = {"recorder": rec} r = cls(**kwargs, **rconf) r.generate() - R.set_tags(**{"train_status": "end_task_train"}) + return rec def task_train(task_config: dict, experiment_name: str) -> Recorder: """ - task based training + Task based training, will be divided into two steps. Parameters ---------- task_config : dict - A dict describes a task setting. + The config of a task. experiment_name: str The name of experiment @@ -97,42 +110,79 @@ class Trainer: The trainer which can train a list of model """ - def train(self, tasks: list, *args, **kwargs): - """Given a list of model definition, begin a training and return the models. + def __init__(self): + self.delay = False + + def train(self, tasks: list, *args, **kwargs) -> list: + """ + Given a list of model definition, begin a training and return the models. + + Args: + tasks: a list of tasks Returns: list: a list of models """ raise NotImplementedError(f"Please implement the `train` method.") - def end_train(self, models, *args, **kwargs): - """Given a list of models, finished something in the end of training if you need. + def end_train(self, models: list, *args, **kwargs) -> list: + """ + Given a list of models, finished something in the end of training if you need. + The models maybe Recorder, txt file, database and so on. + + Args: + models: a list of models Returns: list: a list of models """ - pass + # do nothing if you finished all work in `train` method + return models - def is_delay(self): - return False + def is_delay(self) -> bool: + """ + If Trainer will delay finishing `end_train`. + + Returns: + bool: if DelayTrainer + """ + return self.delay + + def reset(self): + """ + Reset the Trainer status. + """ + pass class TrainerR(Trainer): - """Trainer based on (R)ecorder. + """ + Trainer based on (R)ecorder. + It will train a list of tasks and return a list of model recorder in a linear way. Assumption: models were defined by `task` and the results will saved to `Recorder` """ - def __init__(self, experiment_name, train_func=task_train): + def __init__(self, experiment_name: str, train_func: Callable = task_train): + """ + Init TrainerR. + + Args: + experiment_name (str): the name of experiment. + train_func (Callable, optional): default training method. Defaults to `task_train`. + """ + super().__init__() self.experiment_name = experiment_name self.train_func = train_func - def train(self, tasks: list, train_func=None, *args, **kwargs): - """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + def train(self, tasks: list, train_func: Callable = None, **kwargs) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + kwargs: the params for train_func. Returns: list: a list of Recorders @@ -141,17 +191,74 @@ class TrainerR(Trainer): train_func = self.train_func recs = [] for task in tasks: - recs.append(train_func(task, self.experiment_name, *args, **kwargs)) + rec = train_func(task, self.experiment_name, **kwargs) + rec.set_tags(**{"train_status": "begin_task_train"}) + recs.append(rec) + return recs + + def end_train(self, recs: list, **kwargs) -> list: + for rec in recs: + rec.set_tags(**{"train_status": "end_task_train"}) + return recs + + +class DelayTrainerR(TrainerR): + """ + A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. + """ + + def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): + """ + Init TrainerRM. + + Args: + experiment_name (str): the name of experiment. + train_func (Callable, optional): default train method. Defaults to `begin_task_train`. + end_train_func (Callable, optional): default end_train method. Defaults to `end_task_train`. + """ + super().__init__(experiment_name, train_func) + self.end_train_func = end_train_func + self.delay = True + + def end_train(self, recs, end_train_func=None, **kwargs) -> List[Recorder]: + """ + Given a list of Recorder and return a list of trained Recorder. + This class will finish real data loading and model fitting. + + Args: + recs (list): a list of Recorder, the tasks have been saved to them + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + kwargs: the params for end_train_func. + + Returns: + list: a list of Recorders + """ + if end_train_func is None: + end_train_func = self.end_train_func + for rec in recs: + end_train_func(rec, **kwargs) + rec.set_tags(**{"train_status": "end_task_train"}) return recs class TrainerRM(Trainer): - """Trainer based on (R)ecorder and Task(M)anager + """ + Trainer based on (R)ecorder and Task(M)anager. + It can train a list of tasks and return a list of model recorder in a multiprocessing way. Assumption: `task` will be saved to TaskManager and `task` will be fetched and trained from TaskManager """ def __init__(self, experiment_name: str, task_pool: str, train_func=task_train): + """ + Init TrainerR. + + Args: + experiment_name (str): the name of experiment. + task_pool (str): task pool name in TaskManager. + train_func (Callable, optional): default training method. Defaults to `task_train`. + """ + super().__init__() self.experiment_name = experiment_name self.task_pool = task_pool self.train_func = train_func @@ -159,20 +266,23 @@ class TrainerRM(Trainer): def train( self, tasks: list, - train_func=None, - before_status=TaskManager.STATUS_WAITING, - after_status=TaskManager.STATUS_DONE, - *args, + train_func: Callable = None, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, **kwargs, - ): - """Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. + ) -> List[Recorder]: + """ + Given a list of `task`s and return a list of trained Recorder. The order can be guaranteed. This method defaults to a single process, but TaskManager offered a great way to parallel training. Users can customize their train_func to realize multiple processes or even multiple machines. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. + train_func (Callable): the train method which need at least `task`s and `experiment_name`. None for default training method. + before_status (str): the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status (str): the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs: the params for train_func. Returns: list: a list of Recorders @@ -187,65 +297,27 @@ class TrainerRM(Trainer): experiment_name=self.experiment_name, before_status=before_status, after_status=after_status, - *args, **kwargs, ) recs = [] for _id in _id_list: - recs.append(tm.re_query(_id)["res"]) + rec = tm.re_query(_id)["res"] + rec.set_tags(**{"train_status": "begin_task_train"}) + recs.append(rec) return recs - -class DelayTrainerR(TrainerR): - """ - A delayed implementation based on TrainerR, which means `train` method may only do some preparation and `end_train` method can do the real model fitting. - - """ - - def __init__(self, experiment_name, train_func=begin_task_train, end_train_func=end_task_train): - super().__init__(experiment_name, train_func) - self.end_train_func = end_train_func - self.recs = [] - - def train(self, tasks: list, train_func, *args, **kwargs): - """ - Same as `train` of TrainerR, the results will be recorded in self.recs - - Args: - tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. - - Returns: - list: a list of Recorders - """ - self.recs = super().train(tasks, train_func=train_func, *args, **kwargs) - return self.recs - - def end_train(self, recs=None, end_train_func=None): - """ - Given a list of Recorder and return a list of trained Recorder. - This class will finished real data loading and model fitting. - - Args: - recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs. - end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. - - Returns: - list: a list of Recorders - """ - if recs is None: - recs = copy.deepcopy(self.recs) - # the models will be only trained once - self.recs = [] - if end_train_func is None: - end_train_func = self.end_train_func + def end_train(self, recs: list, **kwargs) -> list: for rec in recs: - end_train_func(rec) + rec.set_tags(**{"train_status": "end_task_train"}) return recs - def is_delay(self): - return True + def reset(self): + """ + NOTE: this method will delete all task in this task_pool! + """ + tm = TaskManager(task_pool=self.task_pool) + tm.remove() class DelayTrainerRM(TrainerRM): @@ -257,28 +329,28 @@ class DelayTrainerRM(TrainerRM): def __init__(self, experiment_name, task_pool: str, train_func=begin_task_train, end_train_func=end_task_train): super().__init__(experiment_name, task_pool, train_func) self.end_train_func = end_train_func + self.delay = True - def train(self, tasks: list, train_func=None, *args, **kwargs): + def train(self, tasks: list, train_func=None, **kwargs): """ - Same as `train` of TrainerRM, the results will be recorded in self.recs - + Same as `train` of TrainerRM, after_status will be STATUS_PART_DONE. Args: tasks (list): a list of definition based on `task` dict - train_func (Callable): the train method which need at least `task` and `experiment_name`. None for default. - + train_func (Callable): the train method which need at least `task`s and `experiment_name`. Defaults to None for using self.train_func. Returns: list: a list of Recorders """ - return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, *args, **kwargs) + return super().train(tasks, train_func=train_func, after_status=TaskManager.STATUS_PART_DONE, **kwargs) - def end_train(self, recs, end_train_func=None): + def end_train(self, recs, end_train_func=None, **kwargs): """ Given a list of Recorder and return a list of trained Recorder. - This class will finished real data loading and model fitting. + This class will finish real data loading and model fitting. Args: - recs (list, optional): a list of Recorder, the tasks have been saved to them. Defaults to None for using self.recs.. - end_train_func (Callable, optional): the end_train method which need at least `rec` and `experiment_name`. Defaults to None for using self.end_train_func. + recs (list): a list of Recorder, the tasks have been saved to them. + end_train_func (Callable, optional): the end_train method which need at least `recorder`s and `experiment_name`. Defaults to None for using self.end_train_func. + kwargs: the params for end_train_func. Returns: list: a list of Recorders @@ -291,8 +363,8 @@ class DelayTrainerRM(TrainerRM): self.task_pool, experiment_name=self.experiment_name, before_status=TaskManager.STATUS_PART_DONE, + **kwargs, ) + for rec in recs: + rec.set_tags(**{"train_status": "end_task_train"}) return recs - - def is_delay(self): - return True diff --git a/qlib/utils/serial.py b/qlib/utils/serial.py index 1b775d99a..52d326c2a 100644 --- a/qlib/utils/serial.py +++ b/qlib/utils/serial.py @@ -3,11 +3,12 @@ from pathlib import Path import pickle +from typing import Union class Serializable: """ - Serializable will change the behaviours of pickle. + Serializable will change the behaviors of pickle. - It only saves the state whose name **does not** start with `_` It provides a syntactic sugar for distinguish the attributes which user doesn't want. - For examples, a learnable Datahandler just wants to save the parameters without data when dumping to disk @@ -70,7 +71,7 @@ class Serializable: obj.config(**params, recursive=True) del self.__dict__[self.FLAG_KEY] - def to_pickle(self, path: [Path, str], dump_all: bool = None, exclude: list = None): + def to_pickle(self, path: Union[Path, str], dump_all: bool = None, exclude: list = None): self.config(dump_all=dump_all, exclude=exclude) with Path(path).open("wb") as f: pickle.dump(self, f) diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index f8266577b..4e9290096 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -2,487 +2,40 @@ # Licensed under the MIT License. """ -This class is a component of online serving, it can manage a series of models dynamically. -With the change of time, the decisive models will be also changed. In this module, we called those contributing models as `online` models. +OnlineManager can manage a set of OnlineStrategy and run them dynamically. + +With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models. In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. So this module provide a series methods to control this process. """ -from copy import deepcopy -from pprint import pprint -import pandas as pd -from qlib.model.ens.ensemble import ens_workflow -from qlib.model.ens.group import RollingGroup -from qlib.utils.serial import Serializable + from typing import Dict, List, Union + +import pandas as pd from qlib import get_module_logger from qlib.data.data import D -from qlib.model.trainer import Trainer, TrainerR, task_train -from qlib.workflow import R +from qlib.utils.serial import Serializable from qlib.workflow.online.strategy import OnlineStrategy -from qlib.workflow.online.update import PredUpdater -from qlib.workflow.recorder import Recorder -from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector -from qlib.workflow.task.gen import RollingGen, task_generator -from qlib.workflow.task.utils import TimeAdjuster, list_recorders +from qlib.workflow.task.collect import HyperCollector + class OnlineManager(Serializable): - - ONLINE_KEY = "online_status" # the online status key in recorder - ONLINE_TAG = "online" # the 'online' model - # NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models. - NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model - OFFLINE_TAG = "offline" # the 'offline' model, not for online serving - - SIGNAL_EXP = "OnlineManagerSignals" # a specific experiment to save signals of different experiment. - - def __init__(self, trainer: Trainer = None, need_log=True): - """ - init OnlineManager. - - Args: - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - need_log (bool, optional): print log or not. Defaults to True. - """ - self.trainer = trainer - self.logger = get_module_logger(self.__class__.__name__) - self.need_log = need_log - self.cur_time = None - - def prepare_signals(self): - """ - After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. - Must use `pass` even though there is nothing to do. - """ - raise NotImplementedError(f"Please implement the `prepare_signals` method.") - - def get_signals(self): - """ - After preparing signals, here is the method to get them. - """ - raise NotImplementedError(f"Please implement the `get_signals` method.") - - def prepare_tasks(self, *args, **kwargs): - """ - After the end of a routine, check whether we need to prepare and train some new tasks. - return the new tasks waiting for training. - """ - raise NotImplementedError(f"Please implement the `prepare_tasks` method.") - - def prepare_new_models(self, tasks, tag=NEXT_ONLINE_TAG, check_func=None, *args, **kwargs): - """ - Use trainer to train a list of tasks and set the trained model to `tag`. - - Args: - tasks (list): a list of tasks. - tag (str): - `ONLINE_TAG` for first train or additional train - `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` - `OFFLINE_TAG` for train but offline those models - check_func: the method to judge if a model can be online. - The parameter is the model record and return True for online. - None for online every models. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. - - """ - if check_func is None: - check_func = lambda x: True - if len(tasks) > 0: - if self.trainer is not None: - new_models = self.trainer.train(tasks, *args, **kwargs) - if check_func(new_models): - self.set_online_tag(tag, new_models) - if self.need_log: - self.logger.info(f"Finished preparing {len(new_models)} new models and set them to {tag}.") - else: - self.logger.warn("No trainer to train new tasks.") - - def update_online_pred(self): - """ - After the end of a routine, update the predictions of online models to latest. - """ - raise NotImplementedError(f"Please implement the `update_online_pred` method.") - - def set_online_tag(self, tag, recorder): - """ - Set `tag` to the model to sign whether online. - - Args: - tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` - """ - raise NotImplementedError(f"Please implement the `set_online_tag` method.") - - def get_online_tag(self): - """ - Given a model and return its online tag. - """ - raise NotImplementedError(f"Please implement the `get_online_tag` method.") - - def reset_online_tag(self, recorders=None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. - - Args: - recorders (List, optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. - - Returns: - list: new online recorder. [] if there is no update. - """ - raise NotImplementedError(f"Please implement the `reset_online_tag` method.") - - def online_models(self): - """ - Return online models. - """ - raise NotImplementedError(f"Please implement the `online_models` method.") - - def first_train(self): - """ - Train a series of models firstly and set some of them into online models. - """ - raise NotImplementedError(f"Please implement the `first_train` method.") - - def get_collector(self): - """ - Return the collector. - - Returns: - Collector - """ - raise NotImplementedError(f"Please implement the `get_collector` method.") - - def delay_prepare(self, rec_dict, *args, **kwargs): - """ - Prepare all models and signals if there are something waiting for prepare. - NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. - - Args: - rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. - """ - for time_segment, recs_list in rec_dict.items(): - self.trainer.end_train(recs_list, *args, **kwargs) - self.reset_online_tag(recs_list) - self.prepare_signals() - signal_max = self.get_signals().index.get_level_values("datetime").max() - if time_segment[1] is not None and signal_max > time_segment[1]: - raise ValueError( - f"The max time of signals prepared by online models is {signal_max}, but those models only online in {time_segment}" - ) - - def routine(self, cur_time=None, delay_prepare=False, *args, **kwargs): - """ - The typical update process after a routine, such as day by day or month by month. - update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models - - NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions. - - Args: - cur_time ([type], optional): [description]. Defaults to None. - delay_prepare (bool, optional): [description]. Defaults to False. - *args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config. - - Returns: - [type]: [description] - """ - self.cur_time = cur_time # None for latest date - if not delay_prepare: - self.update_online_pred() - self.prepare_signals() - tasks = self.prepare_tasks(*args, **kwargs) - self.prepare_new_models(tasks, *args, **kwargs) - - return self.reset_online_tag() - - -class OnlineManagerR(OnlineManager): - """ - The implementation of OnlineManager based on (R)ecorder. - - """ - - def __init__(self, experiment_name: str, trainer: Trainer = None, need_log=True): - """ - init OnlineManagerR. - - Args: - experiment_name (str): the experiment name. - trainer (Trainer, optional): a instance of Trainer. Defaults to None. - need_log (bool, optional): print log or not. Defaults to True. - """ - if trainer is None: - trainer = TrainerR(experiment_name) - super().__init__(trainer=trainer, need_log=need_log) - self.exp_name = experiment_name - self.signal_rec = None - - def set_online_tag(self, tag, recorder: Union[Recorder, List]): - """ - Set `tag` to the model to sign whether online. - - Args: - tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` - recorder (Union[Recorder, List]) - """ - if isinstance(recorder, Recorder): - recorder = [recorder] - for rec in recorder: - rec.set_tags(**{self.ONLINE_KEY: tag}) - if self.need_log: - self.logger.info(f"Set {len(recorder)} models to '{tag}'.") - - def get_online_tag(self, recorder: Recorder): - """ - Given a model and return its online tag. - - Args: - recorder (Recorder): a instance of recorder - - Returns: - str: the tag - """ - tags = recorder.list_tags() - return tags.get(OnlineManager.ONLINE_KEY, OnlineManager.OFFLINE_TAG) - - def reset_online_tag(self, recorder: Union[Recorder, List] = None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. - - Args: - recorders (Union[Recorder, List], optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. - - Returns: - list: new online recorder. [] if there is no update. - """ - if recorder is None: - recorder = list( - list_recorders( - self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.NEXT_ONLINE_TAG - ).values() - ) - if isinstance(recorder, Recorder): - recorder = [recorder] - if len(recorder) == 0: - if self.need_log: - self.logger.info("No 'next online' model, just use current 'online' models.") - return [] - recs = list_recorders(self.exp_name) - self.set_online_tag(OnlineManager.OFFLINE_TAG, list(recs.values())) - self.set_online_tag(OnlineManager.ONLINE_TAG, recorder) - return recorder - - def get_signals(self): - """ - get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) - - Returns: - signals - """ - if self.signal_rec is None: - with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() - signals = None - try: - signals = self.signal_rec.load_object("signals") - except OSError: - self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?") - return signals - - def online_models(self): - """ - Return online models. - - Returns: - list: the list of online models - """ - return list( - list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG).values() - ) - - def update_online_pred(self): - """ - Update all online model predictions to the latest day in Calendar - """ - online_models = self.online_models() - for rec in online_models: - PredUpdater(rec, to_date=self.cur_time, need_log=self.need_log).update() - - if self.need_log: - self.logger.info(f"Finished updating {len(online_models)} online model predictions of {self.exp_name}.") - - def prepare_signals(self, over_write=False): - """ - Average the predictions of online models and offer a trading signals every routine. - The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` - Even if the latest signal already exists, the latest calculation result will be overwritten. - NOTE: Given a prediction of a certain time, all signals before this time will be prepared well. - Args: - over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. - """ - if self.signal_rec is None: - with R.start(experiment_name=self.SIGNAL_EXP, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() - - pred = [] - try: - old_signals = self.signal_rec.load_object("signals") - except OSError: - old_signals = None - - for rec in self.online_models(): - pred.append(rec.load_object("pred.pkl")) - - signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") - signals = signals.sort_index() - if old_signals is not None and not over_write: - old_max = old_signals.index.get_level_values("datetime").max() - new_signals = signals.loc[old_max:] - signals = pd.concat([old_signals, new_signals], axis=0) - else: - new_signals = signals - if self.need_log: - self.logger.info(f"Finished preparing new {len(new_signals)} signals to {self.SIGNAL_EXP}/{self.exp_name}.") - self.signal_rec.save_objects(**{"signals": signals}) - - -class RollingOnlineManager(OnlineManagerR): - """An implementation of OnlineManager based on Rolling.""" - def __init__( self, - experiment_name: str, - rolling_gen: RollingGen, - trainer: Trainer = None, + strategy: Union[OnlineStrategy, List[OnlineStrategy]], + begin_time: Union[str, pd.Timestamp] = None, + freq="day", need_log=True, ): """ - init RollingOnlineManager. + Init OnlineManager. Args: - experiment_name (str): the experiment name. - rolling_gen (RollingGen): an instance of RollingGen - trainer (Trainer, optional): an instance of Trainer. Defaults to None. - collector (Collector, optional): an instance of Collector. Defaults to None. + strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy + begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None. + freq (str, optional): data frequency. Defaults to "day". need_log (bool, optional): print log or not. Defaults to True. """ - if trainer is None: - trainer = TrainerR(experiment_name) - super().__init__(experiment_name=experiment_name, trainer=trainer, need_log=need_log) - self.ta = TimeAdjuster() - self.rg = rolling_gen - self.logger = get_module_logger(self.__class__.__name__) - - def get_collector(self, rec_key_func=None, rec_filter_func=None): - """ - Get the instance of collector to collect results. The returned collector must can distinguish results in different models. - Assumption: the models can be distinguished based on model name and rolling test segments. - If you do not want this assumption, please implement your own method or use another rec_key_func. - - Args: - rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. - """ - - def rec_key(recorder): - task_config = recorder.load_object("task") - model_key = task_config["model"]["class"] - rolling_key = task_config["dataset"]["kwargs"]["segments"]["test"] - return model_key, rolling_key - - if rec_key_func is None: - rec_key_func = rec_key - - return RecorderCollector(experiment=self.exp_name, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func) - - def collect_artifact(self, rec_key_func=None, rec_filter_func=None): - """ - collecting artifact based on the collector and RollingGroup. - - Args: - rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. - rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. - - Returns: - dict: the artifact dict after rolling ensemble - """ - artifact = ens_workflow( - self.get_collector(rec_key_func=rec_key_func, rec_filter_func=rec_filter_func), RollingGroup() - ) - return artifact - - def first_train(self, task_configs: list): - """ - Use rolling_gen to generate different tasks based on task_configs and trained them. - - Args: - task_configs (list or dict): a list of task configs or a task config - - Returns: - Collector: a instance of a Collector. - """ - tasks = task_generator( - tasks=task_configs, - generators=self.rg, # generate different date segment - ) - self.prepare_new_models(tasks, tag=self.ONLINE_TAG) - return self.get_collector() - - def prepare_tasks(self): - """ - Prepare new tasks based on new date. - - Returns: - list: a list of new tasks. - """ - latest_records, max_test = self.list_latest_recorders( - lambda rec: self.get_online_tag(rec) == OnlineManager.ONLINE_TAG - ) - if max_test is None: - self.logger.warn(f"No latest online recorders, no new tasks.") - return [] - calendar_latest = D.calendar(end_time=self.cur_time)[-1] if self.cur_time is None else self.cur_time - if self.need_log: - self.logger.info( - f"The interval between current time {calendar_latest} and last rolling test begin time {max_test[0]} is {self.ta.cal_interval(calendar_latest, max_test[0])}, the rolling step is {self.rg.step}" - ) - if self.ta.cal_interval(calendar_latest, max_test[0]) >= self.rg.step: - old_tasks = [] - tasks_tmp = [] - for rid, rec in latest_records.items(): - task = rec.load_object("task") - old_tasks.append(deepcopy(task)) - test_begin = task["dataset"]["kwargs"]["segments"]["test"][0] - # modify the test segment to generate new tasks - task["dataset"]["kwargs"]["segments"]["test"] = (test_begin, calendar_latest) - tasks_tmp.append(task) - new_tasks_tmp = task_generator(tasks_tmp, self.rg) - new_tasks = [task for task in new_tasks_tmp if task not in old_tasks] - return new_tasks - return [] - - def list_latest_recorders(self, rec_filter_func=None): - """find latest recorders based on test segments. - - Args: - rec_filter_func (Callable, optional): recorder filter. Defaults to None. - - Returns: - dict, tuple: the latest recorders and the latest date of them - """ - recs_flt = list_recorders(self.exp_name, rec_filter_func) - if len(recs_flt) == 0: - return recs_flt, None - max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in recs_flt.values()) - latest_rec = {} - for rid, rec in recs_flt.items(): - if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: - latest_rec[rid] = rec - return latest_rec, max_test - - -class OnlineM(Serializable): - def __init__( - self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], begin_time=None, freq="day", need_log=True - ): self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log if not isinstance(strategy, list): @@ -491,38 +44,37 @@ class OnlineM(Serializable): self.freq = freq if begin_time is None: begin_time = D.calendar(freq=self.freq).max() - self.cur_time = pd.Timestamp(begin_time) + self.begin_time = pd.Timestamp(begin_time) + self.cur_time = self.begin_time self.history = {} def first_train(self): """ - Train a series of models firstly and set some of them into online models. + Run every strategy first_train method and record the online history """ for strategy in self.strategy: self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") online_models = strategy.first_train() self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models - def routine(self, cur_time=None, task_kwargs={}, model_kwargs={}): + def routine(self, cur_time: Union[str, pd.Timestamp] = None, task_kwargs: dict = {}, model_kwargs: dict = {}): """ + Run typical update process for every strategy and record the online history. + The typical update process after a routine, such as day by day or month by month. update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models - NOTE: Assumption: if using simulator (delay_prepare is True), the prediction will be prepared well after every training, so there is no need to update predictions. - Args: - cur_time ([type], optional): [description]. Defaults to None. - delay_prepare (bool, optional): [description]. Defaults to False. - *args, **kwargs: will be passed to `prepare_tasks` and `prepare_new_models`. It can be some hyper parameter or training config. - - Returns: - [type]: [description] + cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None. + task_kwargs (dict): the params for `prepare_tasks` + model_kwargs (dict): the params for `prepare_online_models` """ if cur_time is None: cur_time = D.calendar(freq=self.freq).max() self.cur_time = pd.Timestamp(cur_time) # None for latest date for strategy in self.strategy: - self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") + if self.need_log: + self.logger.info(f"Strategy `{strategy.name_id}` begins routine...") if not strategy.trainer.is_delay(): strategy.prepare_signals() tasks = strategy.prepare_tasks(self.cur_time, **task_kwargs) @@ -530,13 +82,28 @@ class OnlineM(Serializable): if len(online_models) > 0: self.history.setdefault(strategy.name_id, {})[self.cur_time] = online_models - def get_collector(self): + def get_collector(self) -> HyperCollector: + """ + Get the instance of HyperCollector to collect results from every strategy. + + Returns: + HyperCollector: the collector can collect other collectors. + """ collector_dict = {} for strategy in self.strategy: collector_dict[strategy.name_id] = strategy.get_collector() return HyperCollector(collector_dict) - def get_online_history(self, strategy_name_id): + def get_online_history(self, strategy_name_id: str) -> list: + """ + Get the online history based on strategy_name_id. + + Args: + strategy_name_id (str): the name_id of strategy + + Returns: + dict: a list like [(time, [online_models])] + """ history_dict = self.history[strategy_name_id] history = [] for time in sorted(history_dict): @@ -547,22 +114,20 @@ class OnlineM(Serializable): def delay_prepare(self, delay_kwargs={}): """ Prepare all models and signals if there are something waiting for prepare. - NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. Args: - rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. + delay_kwargs: the params for `delay_prepare` """ for strategy in self.strategy: strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs) - def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}): + def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector: """ - Starting from start time, this method will simulate every routine in OnlineManager. + Starting from cur time, this method will simulate every routine in OnlineManager. NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. Returns: - Collector: the OnlineManager's collector + HyperCollector: the OnlineManager's collector """ cal = D.calendar(start_time=self.cur_time, end_time=end_time, freq=frequency) self.first_train() @@ -572,3 +137,12 @@ class OnlineM(Serializable): self.delay_prepare(delay_kwargs=delay_kwargs) self.logger.info(f"Finished preparing signals") return self.get_collector() + + def reset(self): + """ + NOTE: This method will reset all strategy! Be careful to use it. + """ + self.cur_time = self.begin_time + self.history = {} + for strategy in self.strategy: + strategy.reset() diff --git a/qlib/workflow/online/simulator.py b/qlib/workflow/online/simulator.py deleted file mode 100644 index ddaf2471c..000000000 --- a/qlib/workflow/online/simulator.py +++ /dev/null @@ -1,77 +0,0 @@ -from qlib.data import D -from qlib import get_module_logger -from qlib.workflow.online.manager import OnlineM - - -class OnlineSimulator: - """ - To simulate online serving in the past, like a "online serving backtest". - """ - - def __init__( - self, - start_time, - end_time, - online_manager: OnlineManager, - frequency="day", - ): - """ - init OnlineSimulator. - - Args: - start_time (str or pd.Timestamp): the start time of simulating. - end_time (str or pd.Timestamp): the end time of simulating. If None, then end_time is latest. - onlinemanager (OnlineManager): the instance of OnlineManager - frequency (str, optional): the data frequency. Defaults to "day". - """ - self.logger = get_module_logger(self.__class__.__name__) - self.cal = D.calendar(start_time=start_time, end_time=end_time, freq=frequency) - self.start_time = self.cal[0] - self.end_time = self.cal[-1] - self.olm = online_manager - if len(self.cal) == 0: - self.logger.warn(f"There is no need to simulate bacause start_time is larger than end_time.") - - # def simulate(self, *args, **kwargs): - # """ - # Starting from start time, this method will simulate every routine in OnlineManager. - # NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. - - # Returns: - # Collector: the OnlineManager's collector - # """ - # self.rec_dict = {} - # tmp_begin = self.start_time - # tmp_end = None - # self.olm.first_train() - # prev_recorders = self.olm.online_models() - # for cur_time in self.cal: - # self.logger.info(f"Simulating at {str(cur_time)}......") - # recorders = self.olm.routine(cur_time, True, *args, **kwargs) - # if len(recorders) == 0: - # tmp_end = cur_time - # else: - # self.rec_dict[(tmp_begin, tmp_end)] = prev_recorders - # tmp_begin = cur_time - # prev_recorders = recorders - # self.rec_dict[(tmp_begin, self.end_time)] = prev_recorders - # # finished perparing models (and pred) and signals - # self.olm.delay_prepare(self.rec_dict) - # self.logger.info(f"Finished preparing signals") - # return self.olm.get_collector() - - def simulate(self, task_kwargs={}, model_kwargs={}): - """ - Starting from start time, this method will simulate every routine in OnlineManager. - NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. - - Returns: - Collector: the OnlineManager's collector - """ - self.olm.first_train() - for cur_time in self.cal: - self.logger.info(f"Simulating at {str(cur_time)}......") - self.olm.routine(cur_time, task_kwargs={}, model_kwargs={}) - self.olm.delay_prepare() - self.logger.info(f"Finished preparing signals") - return self.olm.get_collector() diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 5e4dcc024..3782ee652 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -1,11 +1,14 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + """ -This module is working with OnlineManager, responsing for a set of strategy about how the models are updated and signals are perpared. +OnlineStrategy is a set of strategy of online serving. +It is working with OnlineManager, responsing how the tasks are generated, the models are updated and signals are perpared. """ from copy import deepcopy -from typing import List, Union +from typing import List, Tuple, Union + import pandas as pd from qlib.data.data import D from qlib.log import get_module_logger @@ -13,7 +16,8 @@ from qlib.model.ens.group import RollingGroup from qlib.model.trainer import Trainer, TrainerR from qlib.workflow import R from qlib.workflow.online.utils import OnlineTool, OnlineToolR -from qlib.workflow.task.collect import HyperCollector, RecorderCollector +from qlib.workflow.recorder import Recorder +from qlib.workflow.task.collect import Collector, HyperCollector, RecorderCollector from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.utils import TimeAdjuster, list_recorders @@ -21,7 +25,7 @@ from qlib.workflow.task.utils import TimeAdjuster, list_recorders class OnlineStrategy: def __init__(self, name_id: str, trainer: Trainer = None, need_log=True): """ - init OnlineManager. + Init OnlineStrategy. Args: name_id (str): a unique name or id @@ -33,12 +37,15 @@ class OnlineStrategy: self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log self.tool = OnlineTool() - self.history = {} - def prepare_signals(self, delay=False): + def prepare_signals(self, delay: bool = False): """ After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. - Must use `pass` even though there is nothing to do. + + NOTE: Given a set prediction, all signals before these prediction end time will be prepared well. + Args: + delay: bool + If this method was called by `delay_prepare` """ raise NotImplementedError(f"Please implement the `prepare_signals` method.") @@ -46,6 +53,8 @@ class OnlineStrategy: """ After the end of a routine, check whether we need to prepare and train some new tasks. return the new tasks waiting for training. + + You can find last online models by OnlineTool.online_models. """ raise NotImplementedError(f"Please implement the `prepare_tasks` method.") @@ -53,6 +62,8 @@ class OnlineStrategy: """ Use trainer to train a list of tasks and set the trained model to `online`. + NOTE: This method will first offline all models and online the online models prepared by this method. So you can find last online models by OnlineTool.online_models if you still need them. + Args: tasks (list): a list of tasks. tag (str): @@ -78,33 +89,43 @@ class OnlineStrategy: def first_train(self): """ - Train a series of models firstly and set some of them into online models. + Train a series of models firstly and set some of them as online models. """ raise NotImplementedError(f"Please implement the `first_train` method.") - def get_collector(self): + def get_collector(self) -> Collector: """ - Return the collector. + Get the instance of collector to collect results of online serving. + + For example: + 1) collect predictions in Recorder + 2) collect signals in .txt file Returns: Collector """ raise NotImplementedError(f"Please implement the `get_collector` method.") - def delay_prepare(self, history, **kwargs): + def delay_prepare(self, history: list, **kwargs): """ Prepare all models and signals if there are something waiting for prepare. - NOTE: Assumption: the predictions of online models are between `time_segment`, or this method will work in a wrong way. + NOTE: Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way. Args: - rec_dict (str): an online models dict likes {(begin_time, end_time):[online models]}. - *args, **kwargs: will be passed to end_train which means will be passed to customized train method. + history (list): an online models list likes [begin_time:[online models]]. + **kwargs: will be passed to end_train which means will be passed to customized train method. """ - for time_begin, recs_list in history: + for begin_time, recs_list in history: self.trainer.end_train(recs_list, **kwargs) self.tool.reset_online_tag(recs_list) self.prepare_signals(delay=True) + def reset(self): + """ + Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation. + """ + pass + class RollingAverageStrategy(OnlineStrategy): @@ -122,7 +143,7 @@ class RollingAverageStrategy(OnlineStrategy): signal_exp_name="OnlineManagerSignals", ): """ - init OnlineManagerR. + Init RollingAverageStrategy. Assumption: the str of name_id, the experiment name and the trainer's experiment name are same one. @@ -139,11 +160,11 @@ class RollingAverageStrategy(OnlineStrategy): if not isinstance(task_template, list): task_template = [task_template] self.task_template = task_template - self.signal_rec = None self.signal_exp_name = signal_exp_name - self.ta = TimeAdjuster() self.rg = rolling_gen self.tool = OnlineToolR(self.exp_name) + self.ta = TimeAdjuster() + self.signal_rec = None # the recorder to record signals def get_collector(self, rec_key_func=None, rec_filter_func=None): """ @@ -180,12 +201,12 @@ class RollingAverageStrategy(OnlineStrategy): ) return HyperCollector({"artifacts": artifacts_collector, "signals": signals_collector}) - def first_train(self): + def first_train(self) -> List[Recorder]: """ Use rolling_gen to generate different tasks based on task_template and trained them. Returns: - Collector: a instance of a Collector. + List[Recorder]: a list of Recorder. """ tasks = task_generator( tasks=self.task_template, @@ -193,12 +214,14 @@ class RollingAverageStrategy(OnlineStrategy): ) return self.prepare_online_models(tasks) - def prepare_tasks(self, cur_time): + def prepare_tasks(self, cur_time) -> List[dict]: """ Prepare new tasks based on cur_time (None for latest). + You can find last online models by OnlineToolR.online_models. + Returns: - list: a list of new tasks. + List[dict]: a list of new tasks. """ latest_records, max_test = self._list_latest(self.tool.online_models()) if max_test is None: @@ -224,7 +247,7 @@ class RollingAverageStrategy(OnlineStrategy): return new_tasks return [] - def prepare_signals(self, delay=False, over_write=False): + def prepare_signals(self, delay=False, over_write=False) -> pd.DataFrame: """ Average the predictions of online models and offer a trading signals every routine. The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` @@ -233,7 +256,7 @@ class RollingAverageStrategy(OnlineStrategy): Args: over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. Returns: - object: the signals. + pd.DataFrame: the signals. """ if not delay: self.tool.update_online_pred() @@ -250,7 +273,7 @@ class RollingAverageStrategy(OnlineStrategy): for rec in self.tool.online_models(): pred.append(rec.load_object("pred.pkl")) - signals = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") + signals: pd.DataFrame = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") signals = signals.sort_index() if old_signals is not None and not over_write: old_max = old_signals.index.get_level_values("datetime").max() @@ -275,14 +298,19 @@ class RollingAverageStrategy(OnlineStrategy): # if self.signal_rec is None: # with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): # self.signal_rec = R.get_recorder() - # signals = None - # try: - # signals = self.signal_rec.load_object("signals") - # except OSError: - # self.logger.warn("Can not find `signals`, have you called `prepare_signals` before?") + # signals = self.signal_rec.load_object("signals") # return signals - def _list_latest(self, rec_list): + def _list_latest(self, rec_list: List[Recorder]): + """ + List latest recorder form rec_list + + Args: + rec_list (List[Recorder]): a list of Recorder + + Returns: + List[Recorder], pd.Timestamp: the latest recorders and its test end time + """ if len(rec_list) == 0: return rec_list, None max_test = max(rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] for rec in rec_list) @@ -291,3 +319,16 @@ class RollingAverageStrategy(OnlineStrategy): if rec.load_object("task")["dataset"]["kwargs"]["segments"]["test"] == max_test: latest_rec.append(rec) return latest_rec, max_test + + def reset(self): + """ + NOTE: This method will delete all recorder in Experiment and reset the Trainer! + """ + self.trainer.reset() + # delete models + exp = R.get_exp(experiment_name=self.exp_name) + for rid in exp.list_recorders(): + exp.delete_recorder(rid) + # delete signals + for rid in list_recorders(self.signal_exp_name, lambda x: True if x.info["name"] == self.exp_name else False): + exp.delete_recorder(rid) diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 5b58360d8..69ad55324 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -1,18 +1,20 @@ -from typing import Union, List -from qlib.data.dataset import DatasetH -from qlib.workflow import R -from qlib.data import D +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Update is a module to update artifacts such as predictions, when the stock data updating. +""" + +from abc import ABCMeta, abstractmethod + import pandas as pd from qlib import get_module_logger -from qlib.workflow import R -from qlib.model import Model -from qlib.model.trainer import task_train -from qlib.workflow.recorder import Recorder -from qlib.workflow.task.utils import list_recorders -from qlib.data.dataset.handler import DataHandlerLP +from qlib.data import D from qlib.data.dataset import DatasetH -from abc import ABCMeta, abstractmethod +from qlib.data.dataset.handler import DataHandlerLP +from qlib.model import Model from qlib.utils import get_date_by_shift +from qlib.workflow.recorder import Recorder class RMDLoader: @@ -25,19 +27,22 @@ class RMDLoader: def get_dataset(self, start_time, end_time, segments=None) -> DatasetH: """ - load, config and setup dataset. + Load, config and setup dataset. - This dataset is for inference + This dataset is for inference. + + Args: + start_time : + the start_time of underlying data + end_time : + the end_time of underlying data + segments : dict + the segments config for dataset + Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time + + Returns: + DatasetH: the instance of DatasetH - Parameters - ---------- - start_time : - the start_time of underlying data - end_time : - the end_time of underlying data - segments : dict - the segments config for dataset - Due to the time series dataset (TSDatasetH), the test segments maybe different from start_time and end_time """ if segments is None: segments = {"test": (start_time, end_time)} @@ -52,7 +57,7 @@ class RMDLoader: class RecordUpdater(metaclass=ABCMeta): """ - Updata a specific recorders + Update a specific recorders """ def __init__(self, record: Recorder, need_log=True, *args, **kwargs): @@ -75,16 +80,17 @@ class PredUpdater(RecordUpdater): def __init__(self, record: Recorder, to_date=None, hist_ref: int = 0, freq="day", need_log=True): """ - Parameters - ---------- - record : Recorder - to_date : - update to prediction to the `to_date` - hist_ref : int - Sometimes, the dataset will have historical depends. - Leave the problem to user to set the length of historical dependancy - NOTE: the start_time is not included in the hist_ref - # TODO: automate this step in the future. + Init PredUpdater. + + Args: + record : Recorder + to_date : + update to prediction to the `to_date` + hist_ref : int + Sometimes, the dataset will have historical depends. + Leave the problem to user to set the length of historical dependency + NOTE: the start_time is not included in the hist_ref + # TODO: automate this step in the future. """ super().__init__(record=record, need_log=need_log) @@ -101,9 +107,12 @@ class PredUpdater(RecordUpdater): def prepare_data(self) -> DatasetH: """ - # Load dataset + Load dataset Seperating this function will make it easier to reuse the dataset + + Returns: + DatasetH: the instance of DatasetH """ start_time_buffer = get_date_by_shift(self.last_end, -self.hist_ref + 1, clip_shift=False, freq=self.freq) start_time = get_date_by_shift(self.last_end, 1, freq=self.freq) @@ -113,9 +122,12 @@ class PredUpdater(RecordUpdater): def update(self, dataset: DatasetH = None): """ - update the precition in a recorder + Update the precition in a recorder + + Args: + DatasetH: the instance of DatasetH. None for reprepare. """ - # FIXME: the problme below is not solved + # FIXME: the problem below is not solved # The model dumped on GPU instances can not be loaded on CPU instance. Follow exception will raised # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. # https://github.com/pytorch/pytorch/issues/16797 diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 1cd89d668..4d630a665 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -1,7 +1,14 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + """ -This module is like a online backend, deciding which models are `online` models and how can change them +OnlineTool is a module to set and unset a series of `online` models. +The `online` models are some decisive models in some time point, which can be changed with the change of time. +This allows us to use efficient submodels as the market style changing. """ + from typing import List, Union + from qlib.log import get_module_logger from qlib.workflow.online.update import PredUpdater from qlib.workflow.recorder import Recorder @@ -12,60 +19,66 @@ class OnlineTool: ONLINE_KEY = "online_status" # the online status key in recorder ONLINE_TAG = "online" # the 'online' model - # NOTE: The meaning of this tag is that we can not assume the training models can be trained before we need its predition. Whenever finished training, it can be guaranteed that there are some online models. - NEXT_ONLINE_TAG = "next_online" # the 'next online' model, which can be 'online' model when call reset_online_model OFFLINE_TAG = "offline" # the 'offline' model, not for online serving def __init__(self, need_log=True): """ - init OnlineTool. + Init OnlineTool. Args: need_log (bool, optional): print log or not. Defaults to True. """ self.logger = get_module_logger(self.__class__.__name__) self.need_log = need_log - self.cur_time = None - def set_online_tag(self, tag, recorder): + def set_online_tag(self, tag, recorder: Union[list, object]): """ Set `tag` to the model to sign whether online. Args: - tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` + tag (str): the tags in `ONLINE_TAG`, `OFFLINE_TAG` + recorder (Union[list,object]): the model's recorder """ raise NotImplementedError(f"Please implement the `set_online_tag` method.") - def get_online_tag(self): + def get_online_tag(self, recorder: object) -> str: """ - Given a model and return its online tag. + Given a model recorder and return its online tag. + + Args: + recorder (Object): the model's recorder + + Returns: + str: the online tag """ raise NotImplementedError(f"Please implement the `get_online_tag` method.") - def reset_online_tag(self, recorders=None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. + def reset_online_tag(self, recorder: Union[list, object]): + """ + Offline all models and set the recorders to 'online'. Args: - recorders (List, optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. + recorder (Union[list,object]): + the recorder you want to reset to 'online'. - Returns: - list: new online recorder. [] if there is no update. """ raise NotImplementedError(f"Please implement the `reset_online_tag` method.") - def online_models(self): + def online_models(self) -> list: """ - Return `online` models. + Get current `online` models + + Returns: + list: a list of `online` models. """ raise NotImplementedError(f"Please implement the `online_models` method.") def update_online_pred(self, to_date=None): """ - Update the predictions of online models to a date. + Update the predictions of `online` models to a date. Args: - to_date (pd.Timestamp): the pred before this date will be updated. None for latest. + to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest. """ raise NotImplementedError(f"Please implement the `update_online_pred` method.") @@ -74,12 +87,11 @@ class OnlineTool: class OnlineToolR(OnlineTool): """ The implementation of OnlineTool based on (R)ecorder. - """ def __init__(self, experiment_name: str, need_log=True): """ - init OnlineToolR. + Init OnlineToolR. Args: experiment_name (str): the experiment name. @@ -90,11 +102,11 @@ class OnlineToolR(OnlineTool): def set_online_tag(self, tag, recorder: Union[Recorder, List]): """ - Set `tag` to the model to sign whether online. + Set `tag` to the model's recorder to sign whether online. Args: tag (str): the tags in `ONLINE_TAG`, `NEXT_ONLINE_TAG`, `OFFLINE_TAG` - recorder (Union[Recorder, List]) + recorder (Union[Recorder, List]): a list of Recorder or an instance of Recorder """ if isinstance(recorder, Recorder): recorder = [recorder] @@ -103,50 +115,40 @@ class OnlineToolR(OnlineTool): if self.need_log: self.logger.info(f"Set {len(recorder)} models to '{tag}'.") - def get_online_tag(self, recorder: Recorder): + def get_online_tag(self, recorder: Recorder) -> str: """ - Given a model and return its online tag. + Given a model recorder and return its online tag. Args: - recorder (Recorder): a instance of recorder + recorder (Recorder): an instance of recorder Returns: - str: the tag + str: the online tag """ tags = recorder.list_tags() return tags.get(self.ONLINE_KEY, self.OFFLINE_TAG) - def reset_online_tag(self, recorder: Union[Recorder, List] = None): - """offline all models and set the recorders to 'online'. If no parameter and no 'next online' model, then do nothing. + def reset_online_tag(self, recorder: Union[Recorder, List]): + """ + Offline all models and set the recorders to 'online'. Args: - recorders (Union[Recorder, List], optional): - the recorders you want to reset to 'online'. If don't give, set 'next online' model to 'online' model. If there isn't any 'next online' model, then maintain existing 'online' model. + recorder (Union[Recorder, List]): + the recorder you want to reset to 'online'. - Returns: - list: new online recorder. [] if there is no update. """ - if recorder is None: - recorder = list( - list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.NEXT_ONLINE_TAG).values() - ) if isinstance(recorder, Recorder): recorder = [recorder] - if len(recorder) == 0: - if self.need_log: - self.logger.info("No 'next online' model, just use current 'online' models.") - return [] recs = list_recorders(self.exp_name) self.set_online_tag(self.OFFLINE_TAG, list(recs.values())) self.set_online_tag(self.ONLINE_TAG, recorder) - return recorder - def online_models(self): + def online_models(self) -> list: """ - Return online models. + Get current `online` models Returns: - list: the list of online models + list: a list of `online` models. """ return list(list_recorders(self.exp_name, lambda rec: self.get_online_tag(rec) == self.ONLINE_TAG).values()) @@ -155,7 +157,7 @@ class OnlineToolR(OnlineTool): Update the predictions of online models to a date. Args: - to_date (pd.Timestamp): the pred before this date will be updated. None for latest in Calendar. + to_date (pd.Timestamp): the pred before this date will be updated. None for update to latest time in Calendar. """ online_models = self.online_models() for rec in online_models: diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index eb0a20029..d74d08184 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -1,9 +1,11 @@ -from abc import abstractmethod -from typing import Callable, Union -from qlib import init +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +""" +Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on. +""" + from qlib.workflow import R -from qlib.workflow.task.utils import list_recorders -from qlib.utils.serial import Serializable import dill as pickle @@ -19,7 +21,7 @@ class Collector: process_list = [process_list] self.process_list = process_list - def collect(self): + def collect(self) -> dict: """Collect the results and return a dict like {key: things} Returns: @@ -36,7 +38,7 @@ class Collector: raise NotImplementedError(f"Please implement the `collect` method.") @staticmethod - def process_collect(collected_dict, process_list=[], *args, **kwargs): + def process_collect(collected_dict, process_list=[], *args, **kwargs) -> dict: """do a series of processing to the dict returned by collect and return a dict like {key: things} For example: you can group and ensemble. @@ -61,7 +63,7 @@ class Collector: result[artifact] = value return result - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> dict: """ do the workflow including collect and process_collect @@ -124,7 +126,7 @@ class HyperCollector(Collector): super().__init__(process_list=process_list) self.collector_dict = collector_dict - def collect(self): + def collect(self) -> dict: collect_dict = {} for key, collector in self.collector_dict.items(): collect_dict[key] = collector() @@ -153,10 +155,10 @@ class RecorderCollector(Collector): artifacts_path (dict, optional): The artifacts name and its path in Recorder. Defaults to {"pred": "pred.pkl", "IC": "sig_analysis/ic.pkl"}. artifacts_key (str or List, optional): the artifacts key you want to get. If None, get all artifacts. """ + super().__init__(process_list=process_list) if isinstance(experiment, str): experiment = R.get_exp(experiment_name=experiment) self.experiment = experiment - self.process_list = process_list self.artifacts_path = artifacts_path if rec_key_func is None: rec_key_func = lambda rec: rec.info["id"] @@ -166,7 +168,7 @@ class RecorderCollector(Collector): self.artifacts_key = artifacts_key self._rec_filter_func = rec_filter_func - def collect(self, artifacts_key=None, rec_filter_func=None): + def collect(self, artifacts_key=None, rec_filter_func=None) -> dict: """Collect different artifacts based on recorder after filtering. Args: @@ -203,5 +205,11 @@ class RecorderCollector(Collector): return collect_dict - def get_exp_name(self): + def get_exp_name(self) -> str: + """ + Get experiment name + + Returns: + str: experiment name + """ return self.experiment.name diff --git a/qlib/workflow/task/gen.py b/qlib/workflow/task/gen.py index 158bc9916..c4c6bab7f 100644 --- a/qlib/workflow/task/gen.py +++ b/qlib/workflow/task/gen.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. """ -this is a task generator +Task generator can generate many tasks based on TaskGen and some task templates. """ import abc import copy @@ -113,7 +113,7 @@ class RollingGen(TaskGen): self.test_key = "test" self.train_key = "train" - def generate(self, task: dict): + def generate(self, task: dict) -> typing.List[dict]: """ Converting the task into a rolling task. @@ -158,6 +158,10 @@ class RollingGen(TaskGen): }, ] } + + Returns + ---------- + typing.List[dict]: a list of tasks """ res = [] @@ -196,16 +200,18 @@ class RollingGen(TaskGen): # update segments of this task t["dataset"]["kwargs"]["segments"] = copy.deepcopy(segments) - # if end_time < the end of test_segments, then change end_time to allow load more data - if ( - self.modify_end_time - and self.ta.cal_interval( + + try: + interval = self.ta.cal_interval( t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"], t["dataset"]["kwargs"]["segments"][self.test_key][1], ) - < 0 - ): - t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) + # if end_time < the end of test_segments, then change end_time to allow load more data + if self.modify_end_time and interval < 0: + t["dataset"]["kwargs"]["handler"]["kwargs"]["end_time"] = copy.deepcopy(segments[self.test_key][1]) + except KeyError: + # Maybe the user dataset has no handler or end_time + pass prev_seg = segments res.append(t) return res diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 9d50d8563..3c3144fe8 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -1,31 +1,39 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + """ -A task consists of 3 parts +TaskManager can fetch unused tasks automatically and manager the lifecycle of a set of tasks with error handling. +These features can run tasks concurrently and ensure every task will be used only once. +Task Manager will store all tasks in `MongoDB `_. +Users **MUST** finished the configuration of `MongoDB `_ when using this module. + +A task in TaskManager consists of 3 parts - tasks description: the desc will define the task - tasks status: the status of the task - tasks result information : A user can get the task with the task description and task result. - """ -from bson.binary import Binary -import pickle -from pymongo.errors import InvalidDocument -from bson.objectid import ObjectId -from contextlib import contextmanager -import qlib -from tqdm.cli import tqdm -import time import concurrent -import pymongo -from qlib.config import C -from .utils import get_mongodb -from qlib import get_module_logger, auto_init +import pickle +import time +from contextlib import contextmanager +from typing import Callable, List + import fire +import pymongo +from bson.binary import Binary +from bson.objectid import ObjectId +from pymongo.errors import InvalidDocument +from qlib import auto_init, get_module_logger +from tqdm.cli import tqdm + +from .utils import get_mongodb class TaskManager: - """TaskManager - here is what will a task looks like when it created by TaskManager + """ + TaskManager + + Here is what will a task looks like when it created by TaskManager .. code-block:: python @@ -42,6 +50,12 @@ class TaskManager: .. note:: Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded + + Here are four status which are: + STATUS_WAITING: waiting for train + STATUS_RUNNING: training + STATUS_PART_DONE: finished some step and waiting for next step. + STATUS_DONE: all work done """ STATUS_WAITING = "waiting" @@ -53,7 +67,7 @@ class TaskManager: def __init__(self, task_pool: str = None): """ - init Task Manager, remember to make the statement of MongoDB url and database name firstly. + Init Task Manager, remember to make the statement of MongoDB url and database name firstly. Parameters ---------- @@ -65,7 +79,7 @@ class TaskManager: self.task_pool = getattr(self.mdb, task_pool) self.logger = get_module_logger(self.__class__.__name__) - def list(self): + def list(self) -> list: """ list the all collection(task_pool) of the db @@ -92,7 +106,9 @@ class TaskManager: return {k: str(v) for k, v in flt.items()} def replace_task(self, task, new_task): - # assume that the data out of interface was decoded and the data in interface was encoded + """ + Use a new task to replace a old one + """ new_task = self._encode_task(new_task) query = {"_id": ObjectId(task["_id"])} try: @@ -121,7 +137,7 @@ class TaskManager: Returns ------- - + pymongo.results.InsertOneResult """ task = self._encode_task( { @@ -133,9 +149,9 @@ class TaskManager: insert_result = self.insert_task(task) return insert_result - def create_task(self, task_def_l, dry_run=False, print_nt=False): + def create_task(self, task_def_l, dry_run=False, print_nt=False) -> List[str]: """ - if the tasks in task_def_l is new, then insert new tasks into the task_pool + If the tasks in task_def_l is new, then insert new tasks into the task_pool Parameters ---------- @@ -145,6 +161,7 @@ class TaskManager: if insert those new tasks to task pool print_nt: bool if print new task + Returns ------- list @@ -165,7 +182,7 @@ class TaskManager: print(t) if dry_run: - return + return [] _id_list = [] for t in new_tasks: @@ -174,7 +191,17 @@ class TaskManager: return _id_list - def fetch_task(self, query={}, status=STATUS_WAITING): + def fetch_task(self, query={}, status=STATUS_WAITING) -> dict: + """ + Use query to fetch tasks + + Args: + query (dict, optional): query dict. Defaults to {}. + status (str, optional): [description]. Defaults to STATUS_WAITING. + + Returns: + dict: a task(document in collection) after decoding + """ query = query.copy() if "_id" in query: query["_id"] = ObjectId(query["_id"]) @@ -191,7 +218,7 @@ class TaskManager: @contextmanager def safe_fetch_task(self, query={}, status=STATUS_WAITING): """ - fetch task from task_pool using query with contextmanager + Fetch task from task_pool using query with contextmanager Parameters ---------- @@ -200,7 +227,7 @@ class TaskManager: Returns ------- - + dict: a task(document in collection) after decoding """ task = self.fetch_task(query=query, status=status) try: @@ -231,7 +258,7 @@ class TaskManager: Returns ------- - + dict: a task(document in collection) after decoding """ query = query.copy() if "_id" in query: @@ -240,16 +267,40 @@ class TaskManager: yield self._decode_task(t) def re_query(self, _id): + """ + Use _id to query task. + + Args: + _id (str): _id of a document + + Returns: + dict: a task(document in collection) after decoding + """ t = self.task_pool.find_one({"_id": ObjectId(_id)}) return self._decode_task(t) - def commit_task_res(self, task, res, status=None): + def commit_task_res(self, task, res, status=STATUS_DONE): + """ + Commit the result to task['res']. + + Args: + task ([type]): [description] + res (object): the result you want to save + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_DONE. + """ # A workaround to use the class attribute. if status is None: status = TaskManager.STATUS_DONE self.task_pool.update_one({"_id": task["_id"]}, {"$set": {"status": status, "res": Binary(pickle.dumps(res))}}) - def return_task(self, task, status=None): + def return_task(self, task, status=STATUS_WAITING): + """ + Return a task to status. Alway using in error handling. + + Args: + task ([type]): [description] + status (str, optional): STATUS_WAITING, STATUS_RUNNING, STATUS_DONE, STATUS_PART_DONE. Defaults to STATUS_WAITING. + """ if status is None: status = TaskManager.STATUS_WAITING update_dict = {"$set": {"status": status}} @@ -257,7 +308,7 @@ class TaskManager: def remove(self, query={}): """ - remove the task using query + Remove the task using query Parameters ---------- @@ -295,7 +346,7 @@ class TaskManager: def prioritize(self, task, priority: int): """ - set priority for task + Set priority for task Parameters ---------- @@ -331,29 +382,37 @@ class TaskManager: def run_task( - task_func, - task_pool, - force_release=False, - before_status=TaskManager.STATUS_WAITING, - after_status=TaskManager.STATUS_DONE, - *args, + task_func: Callable, + task_pool: str, + force_release: bool = False, + before_status: str = TaskManager.STATUS_WAITING, + after_status: str = TaskManager.STATUS_DONE, **kwargs, ): """ While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool + After running this method, here are 4 situations (before_status -> after_status): + STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param + STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param + STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param + STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param + Parameters ---------- - task_func : def (task_def, *args, **kwargs) -> - the function to run the task + task_func : Callable + def (task_def, **kwargs) -> + the function to run the task task_pool : str the name of the task pool (Collection in MongoDB) - force_release : + force_release : bool will the program force to release the resource - args : - args - kwargs : - kwargs + before_status : str: + the tasks in before_status will be fetched and trained. Can be STATUS_WAITING, STATUS_PART_DONE. + after_status : str: + the tasks after trained will become after_status. Can be STATUS_WAITING, STATUS_PART_DONE. + kwargs + the params for `task_func` """ tm = TaskManager(task_pool) @@ -364,19 +423,19 @@ def run_task( if task is None: break get_module_logger("run_task").info(task["def"]) - # when fetching `WAITING` task, use task_def to train + # when fetching `WAITING` task, use task["def"] to train if before_status == TaskManager.STATUS_WAITING: param = task["def"] - # when fetching `PART_DONE` task, use task_res to train for the result has been saved + # when fetching `PART_DONE` task, use task["res"] to train because the middle result has been saved to task["res"] elif before_status == TaskManager.STATUS_PART_DONE: param = task["res"] else: raise ValueError("The fetched task must be `STATUS_WAITING` or `STATUS_PART_DONE`!") if force_release: with concurrent.futures.ProcessPoolExecutor(max_workers=1) as executor: - res = executor.submit(task_func, param, *args, **kwargs).result() + res = executor.submit(task_func, param, **kwargs).result() else: - res = task_func(param, *args, **kwargs) + res = task_func(param, **kwargs) tm.commit_task_res(task, res, status=after_status) ever_run = True diff --git a/qlib/workflow/task/utils.py b/qlib/workflow/task/utils.py index ce8e0dfa3..ed5e1a235 100644 --- a/qlib/workflow/task/utils.py +++ b/qlib/workflow/task/utils.py @@ -1,5 +1,10 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +""" +Some tools for task management. +""" + import bisect import pandas as pd from qlib.data import D @@ -7,13 +12,14 @@ from qlib.workflow import R from qlib.config import C from qlib.log import get_module_logger from pymongo import MongoClient +from pymongo.database import Database from typing import Union -def get_mongodb(): - """ +def get_mongodb() -> Database: - get database in MongoDB, which means you need to declare the address and the name of database. + """ + Get database in MongoDB, which means you need to declare the address and the name of database. for example: Using qlib.init(): @@ -31,6 +37,8 @@ def get_mongodb(): "task_db_name" : "rolling_db" } + Returns: + Database: the Database instance """ try: cfg = C["mongo"] @@ -43,7 +51,8 @@ def get_mongodb(): def list_recorders(experiment, rec_filter_func=None): - """list all recorders which can pass the filter in a experiment. + """ + List all recorders which can pass the filter in a experiment. Args: experiment (str or Experiment): the name of a Experiment or a instance @@ -65,7 +74,7 @@ def list_recorders(experiment, rec_filter_func=None): class TimeAdjuster: """ - find appropriate date and adjust date. + Find appropriate date and adjust date. """ def __init__(self, future=True, end_time=None): @@ -88,15 +97,15 @@ class TimeAdjuster: return None return self.cals[idx] - def max(self): + def max(self) -> pd.Timestamp: """ Return the max calendar datetime """ return max(self.cals) - def align_idx(self, time_point, tp_type="start"): + def align_idx(self, time_point, tp_type="start") -> int: """ - align the index of time_point in the calendar + Align the index of time_point in the calendar Parameters ---------- @@ -116,9 +125,9 @@ class TimeAdjuster: raise NotImplementedError(f"This type of input is not supported") return idx - def cal_interval(self, time_point_A, time_point_B): + def cal_interval(self, time_point_A, time_point_B) -> int: """ - calculate the trading day interval + Calculate the trading day interval (time_point_A - time_point_B) Args: time_point_A : time_point_A @@ -129,20 +138,22 @@ class TimeAdjuster: """ return self.align_idx(time_point_A) - self.align_idx(time_point_B) - def align_time(self, time_point, tp_type="start"): + def align_time(self, time_point, tp_type="start") -> pd.Timestamp: """ Align time_point to trade date of calendar - Parameters - ---------- - time_point - Time point - tp_type : str - time point type (`"start"`, `"end"`) + Args: + time_point + Time point + tp_type : str + time point type (`"start"`, `"end"`) + + Returns: + pd.Timestamp """ return self.cals[self.align_idx(time_point, tp_type=tp_type)] - def align_seg(self, segment: Union[dict, tuple]): + def align_seg(self, segment: Union[dict, tuple]) -> Union[dict, tuple]: """ align the given date to trade date @@ -162,7 +173,7 @@ class TimeAdjuster: Returns ------- - the start and end trade date (pd.Timestamp) between the given start and end date. + Union[dict, tuple]: the start and end trade date (pd.Timestamp) between the given start and end date. """ if isinstance(segment, dict): return {k: self.align_seg(seg) for k, seg in segment.items()} @@ -171,7 +182,7 @@ class TimeAdjuster: else: raise NotImplementedError(f"This type of input is not supported") - def truncate(self, segment: tuple, test_start, days: int): + def truncate(self, segment: tuple, test_start, days: int) -> tuple: """ truncate the segment based on the test_start date @@ -183,6 +194,10 @@ class TimeAdjuster: days : int The trading days to be truncated the data in this segment may need 'days' data + + Returns + --------- + tuple: new segment """ test_idx = self.align_idx(test_start) if isinstance(segment, tuple): @@ -198,7 +213,7 @@ class TimeAdjuster: SHIFT_SD = "sliding" SHIFT_EX = "expanding" - def shift(self, seg: tuple, step: int, rtype=SHIFT_SD): + def shift(self, seg: tuple, step: int, rtype=SHIFT_SD) -> tuple: """ shift the datatime of segment @@ -211,6 +226,10 @@ class TimeAdjuster: rtype : str rolling type ("sliding" or "expanding") + Returns + -------- + tuple: new segment + Raises ------ KeyError: From f58c61a2e0c313074729da6715d30d58e1503e69 Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 29 Apr 2021 16:54:51 +0800 Subject: [PATCH 24/30] Fix logger pickling error --- qlib/log.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/qlib/log.py b/qlib/log.py index 5888b3841..1d604e0c0 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -17,6 +17,7 @@ class MetaLogger(type): wrapper_dict = logging.Logger.__dict__.copy() wrapper_dict.update(dict) wrapper_dict["__doc__"] = logging.Logger.__doc__ + del wrapper_dict["__reduce__"] # make Logger object can be pickled return type.__new__(cls, name, bases, wrapper_dict) @@ -29,6 +30,15 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 + def __getstate__(self): + return vars(self) + + def __setstate__(self, state): + vars(self).update(state) + + def __reduce__(self): + return (QlibLogger, (self.module_name,)) + @property def logger(self): logger = logging.getLogger(self.module_name) From ca92cb980ca9a49d9c41f98e5f2c2c6941a8a1ae Mon Sep 17 00:00:00 2001 From: Jactus Date: Thu, 29 Apr 2021 22:40:52 +0800 Subject: [PATCH 25/30] Update meta logger --- qlib/log.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 1d604e0c0..19331f5d5 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -15,10 +15,11 @@ from .config import C class MetaLogger(type): def __new__(cls, name, bases, dict): wrapper_dict = logging.Logger.__dict__.copy() - wrapper_dict.update(dict) - wrapper_dict["__doc__"] = logging.Logger.__doc__ - del wrapper_dict["__reduce__"] # make Logger object can be pickled - return type.__new__(cls, name, bases, wrapper_dict) + for key in wrapper_dict: + if key not in dict and key != "__reduce__": + dict[key] = wrapper_dict[key] + dict["__doc__"] = logging.Logger.__doc__ + return type.__new__(cls, name, bases, dict) class QlibLogger(metaclass=MetaLogger): From 51b649ec395f4a80e96dd88b51ebdd8d2a192db2 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 13:13:05 +0800 Subject: [PATCH 26/30] Update QlibLogger --- qlib/log.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index 19331f5d5..d095d571a 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -31,12 +31,6 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 - def __getstate__(self): - return vars(self) - - def __setstate__(self, state): - vars(self).update(state) - def __reduce__(self): return (QlibLogger, (self.module_name,)) @@ -50,6 +44,9 @@ class QlibLogger(metaclass=MetaLogger): self.level = level def __getattr__(self, name): + # During unpickling, python will call __getattr__. Use this line to avoid maximum recursion error. + if name in {"__setstate__"}: + raise AttributeError return self.logger.__getattribute__(name) From 694ae3402766e582a6c067de807a997f1a9719c4 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 13:27:19 +0800 Subject: [PATCH 27/30] Update api --- qlib/workflow/__init__.py | 21 ++++++++++++++++++--- qlib/workflow/exp.py | 4 ++-- qlib/workflow/expm.py | 4 +++- 3 files changed, 23 insertions(+), 6 deletions(-) diff --git a/qlib/workflow/__init__.py b/qlib/workflow/__init__.py index 7cb1cf5cb..8135bab60 100644 --- a/qlib/workflow/__init__.py +++ b/qlib/workflow/__init__.py @@ -23,6 +23,7 @@ class QlibRecorder: @contextmanager def start( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -63,7 +64,14 @@ class QlibRecorder: resume : bool whether to resume the specific recorder with given name under the given experiment. """ - run = self.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) + run = self.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) try: yield run except Exception as e: @@ -72,7 +80,7 @@ class QlibRecorder: self.end_exp(Recorder.STATUS_FI) def start_exp( - self, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False + self, *, experiment_id=None, experiment_name=None, recorder_id=None, recorder_name=None, uri=None, resume=False ): """ Lower level method for starting an experiment. When use this method, one should end the experiment manually @@ -105,7 +113,14 @@ class QlibRecorder: ------- An experiment instance being started. """ - return self.exp_manager.start_exp(experiment_id, experiment_name, recorder_id, recorder_name, uri, resume) + return self.exp_manager.start_exp( + experiment_id=experiment_id, + experiment_name=experiment_name, + recorder_id=recorder_id, + recorder_name=recorder_name, + uri=uri, + resume=resume, + ) def end_exp(self, recorder_status=Recorder.STATUS_FI): """ diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 0a7e0a5a9..467c7c3f4 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -39,7 +39,7 @@ class Experiment: output["recorders"] = list(recorders.keys()) return output - def start(self, recorder_id=None, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): """ Start the experiment and set it to be active. This method will also start a new recorder. @@ -240,7 +240,7 @@ class MLflowExperiment(Experiment): def __repr__(self): return "{name}(id={id}, info={info})".format(name=self.__class__.__name__, id=self.id, info=self.info) - def start(self, recorder_id=None, recorder_name=None, resume=False): + def start(self, *, recorder_id=None, recorder_name=None, resume=False): logger.info(f"Experiment {self.id} starts running ...") # Get or create recorder if recorder_name is None: diff --git a/qlib/workflow/expm.py b/qlib/workflow/expm.py index 5549bb9bf..04cc3bcb7 100644 --- a/qlib/workflow/expm.py +++ b/qlib/workflow/expm.py @@ -33,6 +33,7 @@ class ExpManager: def start_exp( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -304,6 +305,7 @@ class MLflowExpManager(ExpManager): def start_exp( self, + *, experiment_id: Optional[Text] = None, experiment_name: Optional[Text] = None, recorder_id: Optional[Text] = None, @@ -320,7 +322,7 @@ class MLflowExpManager(ExpManager): # Set up active experiment self.active_experiment = experiment # Start the experiment - self.active_experiment.start(recorder_id, recorder_name, resume) + self.active_experiment.start(recorder_id=recorder_id, recorder_name=recorder_name, resume=resume) return self.active_experiment From 5eb9dfff166b79cdd2e00bc0ff7430f266db46b0 Mon Sep 17 00:00:00 2001 From: Jactus Date: Fri, 30 Apr 2021 15:28:37 +0800 Subject: [PATCH 28/30] Remove redundant --- qlib/log.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/qlib/log.py b/qlib/log.py index d095d571a..e714bc15a 100644 --- a/qlib/log.py +++ b/qlib/log.py @@ -18,7 +18,6 @@ class MetaLogger(type): for key in wrapper_dict: if key not in dict and key != "__reduce__": dict[key] = wrapper_dict[key] - dict["__doc__"] = logging.Logger.__doc__ return type.__new__(cls, name, bases, dict) @@ -31,9 +30,6 @@ class QlibLogger(metaclass=MetaLogger): self.module_name = module_name self.level = 0 - def __reduce__(self): - return (QlibLogger, (self.module_name,)) - @property def logger(self): logger = logging.getLogger(self.module_name) From 5bc2b96346605404faa571e76ee7c37755514b0c Mon Sep 17 00:00:00 2001 From: you-n-g Date: Mon, 3 May 2021 12:34:08 +0800 Subject: [PATCH 29/30] Update data.rst --- docs/component/data.rst | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/component/data.rst b/docs/component/data.rst index 26f44a076..3cee803e6 100644 --- a/docs/component/data.rst +++ b/docs/component/data.rst @@ -182,6 +182,11 @@ The `trade unit` defines the unit number of stocks can be used in a trade, and t qlib.init(provider_uri='~/.qlib/qlib_data/us_data', region=REG_US) +.. note:: + + PRs for new data source are highly welcome! Users could commit the code to crawl data as a PR like `the examples here `_. And then we will use the code to create data cache on our server which other users could use directly. + + Data API ======================== From 84c56f13bd47ee45ae50ec74c5a154295cf55a43 Mon Sep 17 00:00:00 2001 From: lzh222333 Date: Thu, 6 May 2021 04:18:55 +0000 Subject: [PATCH 30/30] docs and bug fixed --- docs/advanced/task_management.rst | 48 ++++++---- docs/component/online.rst | 41 ++++++++ docs/index.rst | 1 + docs/reference/api.rst | 55 +++++++++-- .../online_srv/online_management_simulate.py | 2 + .../online_srv/rolling_online_management.py | 4 +- qlib/data/dataset/__init__.py | 55 +++-------- qlib/data/dataset/handler.py | 8 +- qlib/model/ens/ensemble.py | 46 +++++++++ qlib/model/ens/group.py | 7 ++ qlib/model/trainer.py | 16 ++-- qlib/workflow/online/manager.py | 51 +++++++--- qlib/workflow/online/strategy.py | 95 ++++++++++--------- qlib/workflow/online/update.py | 10 +- qlib/workflow/online/utils.py | 3 + qlib/workflow/task/collect.py | 5 +- qlib/workflow/task/manage.py | 10 +- 17 files changed, 312 insertions(+), 145 deletions(-) create mode 100644 docs/component/online.rst diff --git a/docs/advanced/task_management.rst b/docs/advanced/task_management.rst index 230a4e9d1..a68c12627 100644 --- a/docs/advanced/task_management.rst +++ b/docs/advanced/task_management.rst @@ -1,4 +1,4 @@ -.. _task_managment: +.. _task_management: ================================= Task Management @@ -10,15 +10,17 @@ Introduction ============= The `Workflow <../component/introduction.html>`_ part introduces how to run research workflow in a loosely-coupled way. But it can only execute one ``task`` when you use ``qrun``. -To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Running`_ and `Task Collecting`_. +To automatically generate and execute different tasks, ``Task Management`` provides a whole process including `Task Generating`_, `Task Storing`_, `Task Training`_ and `Task Collecting`_. With this module, users can run their ``task`` automatically at different periods, in different losses, or even by different models. -An example of the entire process is shown `here `_. +This whole process can be used in `Online Serving <../component/online.html>`_. + +An example of the entire process is shown `here `_. Task Generating =============== A ``task`` consists of `Model`, `Dataset`, `Record` or anything added by users. -The specific task template(/definition/config) can be viewed in +The specific task template can be viewed in `Task Section <../component/workflow.html#task-section>`_. Even though the task template is fixed, users can customize their ``TaskGen`` to generate different ``task`` by task template. @@ -27,15 +29,16 @@ Here is the base class of ``TaskGen``: .. autoclass:: qlib.workflow.task.gen.TaskGen :members: -``Qlib`` provider a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments. -This class allows users to verify the effect of data from different periods on the model in one experiment. +``Qlib`` provides a class `RollingGen `_ to generate a list of ``task`` of the dataset in different date segments. +This class allows users to verify the effect of data from different periods on the model in one experiment. More information in `here <../reference/api.html#TaskGen>`_. Task Storing =============== To achieve higher efficiency and the possibility of cluster operation, ``Task Manager`` will store all tasks in `MongoDB `_. +``TaskManager`` can fetch undone tasks automatically and manage the lifecycle of a set of tasks with error handling. Users **MUST** finished the configuration of `MongoDB `_ when using this module. -Users need to provide the URL and database name of ``task`` storing like this. +Users need to provide the MongoDB URL and database name for using ``TaskManager`` in `initialization <../start/initialization.html#Parameters>`_ or make statement like this. .. code-block:: python @@ -45,13 +48,12 @@ Users need to provide the URL and database name of ``task`` storing like this. "task_db_name" : "rolling_db" # database name } -The CRUD methods of ``task`` can be found in TaskManager. -More methods can be seen in the `Github `_. - .. autoclass:: qlib.workflow.task.manage.TaskManager :members: -Task Running +More information of ``Task Manager`` can be found in `here <../reference/api.html#TaskManager>`_. + +Task Training =============== After generating and storing those ``task``, it's time to run the ``task`` which are in the *WAITING* status. ``Qlib`` provides a method called ``run_task`` to run those ``task`` in task pool, however, users can also customize how tasks are executed. @@ -60,14 +62,24 @@ It will run the whole workflow defined by ``task``, which includes *Model*, *Dat .. autofunction:: qlib.workflow.task.manage.run_task +Meanwhile, ``Qlib`` provides a module called ``Trainer``. +``Trainer`` will train a list of tasks and return a list of model recorder. +``Qlib`` offer two kind of Trainer, TrainerR is the simplest way and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically. +If you do not want to use ``Task Manager`` to manage tasks, then use TrainerR to train a list of tasks generated by ``TaskGen`` is enough. +More information is in `here <../reference/api.html#Trainer>`_. + Task Collecting =============== -To see the results of ``task`` after running or to update something, ``Qlib`` provides a ``TaskCollector`` to collect the tasks by filter condition (optional). -Here are some methods in this class. +To collect the results of ``task`` after training, ``Qlib`` provides `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_ to collect the results in a readable, expandable and loosely-coupled way. -.. autoclass:: qlib.workflow.task.collect.TaskCollector - :members: +`Collector <../reference/api.html#Collector>`_ can collect object from everywhere and process them such as merging, grouping, averaging and so on. It has 2 step action including ``collect`` (collect anything in a dict) and ``process_collect`` (process collected dict). -``Qlib`` provides a concrete `example `_, including a whole process of `Task Generating`_ (using `RollingGen `_), `Task Storing`_, `Task Running`_ and `Task Collecting`_. -Besides, the `example `_ uses a ``ModelUpdater`` inherited from ``TaskCollector``, which can update the inferences and retrain the model if it is out of date. -Actually, the model updating can be viewed as a subset of ``Online Serving``. \ No newline at end of file +`Group <../reference/api.html#Group>`_ also has 2 steps including ``group`` (can group a set of object based on `group_func` and change them to a dict) and ``reduce`` (can make a dict become an ensemble based on some rule). +For example: {(A,B,C1): object, (A,B,C2): object} ---``group``---> {(A,B): {C1: object, C2: object}} ---``reduce``---> {(A,B): object} + +`Ensemble <../reference/api.html#Ensemble>`_ can merge the objects in an ensemble. +For example: {C1: object, C2: object} ---``Ensemble``---> object + +So the hierarchy is ``Collector``'s second step correspond to ``Group``. And ``Group``'s second step correspond to ``Ensemble``. + +For more information, please see `Collector <../reference/api.html#Collector>`_, `Group <../reference/api.html#Group>`_ and `Ensemble <../reference/api.html#Ensemble>`_, or the `example `_ \ No newline at end of file diff --git a/docs/component/online.rst b/docs/component/online.rst new file mode 100644 index 000000000..e25173153 --- /dev/null +++ b/docs/component/online.rst @@ -0,0 +1,41 @@ +.. _online: + +================================= +Online Serving +================================= +.. currentmodule:: qlib + + +Introduction +============= +In addition to backtesting, one way to test a model is effective is to make predictions in real market conditions or even do real trading based on those predictions. +``Online Serving`` is a set of module for online models using latest data, +which including `Online Manager <#Online Manager>`_, `Online Strategy <#Online Strategy>`_, `Online Tool <#Online Tool>`_, `Updater <#Updater>`_. + +`Here `_ are several examples for reference, which demonstrate different features of ``Online Serving``. +If you have many models or `task` need to be managed, please consider `Task Management <../advanced/task_management.html>`_. +The `examples `_ maybe based on `Task Management <../advanced/task_management.html>`_ such as ``TrainerRM`` or ``Collector``. + +Online Manager +============= + +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +============= + +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +============= + +.. automodule:: qlib.workflow.online.utils + :members: + +Updater +============= + +.. automodule:: qlib.workflow.online.update + :members: \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 274dc8045..803aa97d2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -42,6 +42,7 @@ Document Structure Intraday Trading: Model&Strategy Testing Qlib Recorder: Experiment Management Analysis: Evaluation & Results Analysis + Online Serving: Online Management & Strategy & Tool .. toctree:: :maxdepth: 3 diff --git a/docs/reference/api.rst b/docs/reference/api.rst index 691dff703..edba6228a 100644 --- a/docs/reference/api.rst +++ b/docs/reference/api.rst @@ -154,36 +154,71 @@ Record Template .. automodule:: qlib.workflow.record_temp :members: - Task Management ==================== -RollingGen +TaskGen -------------------- -.. autoclass:: qlib.workflow.task.gen.RollingGen +.. automodule:: qlib.workflow.task.gen :members: TaskManager -------------------- -.. autoclass:: qlib.workflow.task.manage.TaskManager +.. automodule:: qlib.workflow.task.manage :members: -TaskCollector +Trainer -------------------- -.. autoclass:: qlib.workflow.task.collect.TaskCollector +.. automodule:: qlib.model.trainer :members: -ModelUpdater +Collector -------------------- -.. autoclass:: qlib.workflow.task.update.ModelUpdater +.. automodule:: qlib.workflow.task.collect :members: -TimeAdjuster +Group -------------------- -.. autoclass:: qlib.workflow.task.utils.TimeAdjuster +.. automodule:: qlib.model.ens.group :members: +Ensemble +-------------------- +.. automodule:: qlib.model.ens.ensemble + :members: + +Utils +-------------------- +.. automodule:: qlib.workflow.task.utils + :members: + + +Online Serving +==================== + + +Online Manager +-------------------- +.. automodule:: qlib.workflow.online.manager + :members: + +Online Strategy +-------------------- +.. automodule:: qlib.workflow.online.strategy + :members: + +Online Tool +-------------------- +.. automodule:: qlib.workflow.online.utils + :members: + +RecordUpdater +-------------------- +.. automodule:: qlib.workflow.online.update + :members: + + Utils ==================== diff --git a/examples/online_srv/online_management_simulate.py b/examples/online_srv/online_management_simulate.py index 16e985ccd..7be46d999 100644 --- a/examples/online_srv/online_management_simulate.py +++ b/examples/online_srv/online_management_simulate.py @@ -131,6 +131,8 @@ class OnlineSimulationExample: self.rolling_online_manager.simulate(end_time=self.end_time) print("========== collect results ==========") print(self.rolling_online_manager.get_collector()()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) print("========== online history ==========") print(self.rolling_online_manager.get_online_history(self.exp_name)) diff --git a/examples/online_srv/rolling_online_management.py b/examples/online_srv/rolling_online_management.py index 950c9684d..25b6fc4da 100644 --- a/examples/online_srv/rolling_online_management.py +++ b/examples/online_srv/rolling_online_management.py @@ -86,7 +86,7 @@ class RollingOnlineExample: task_url="mongodb://10.0.0.4:27017/", task_db_name="rolling_db", rolling_step=550, - tasks=[task_xgboost_config, task_lgb_config], + tasks=[task_xgboost_config], # , task_lgb_config], ): mongo_conf = { "task_url": task_url, # your MongoDB url @@ -148,6 +148,8 @@ class RollingOnlineExample: self.rolling_online_manager.routine() print("========== collect results ==========") print(self.collector()) + print("========== signals ==========") + print(self.rolling_online_manager.get_signals()) def main(self): self.first_run() diff --git a/qlib/data/dataset/__init__.py b/qlib/data/dataset/__init__.py index 4457dda5f..4ae73c670 100644 --- a/qlib/data/dataset/__init__.py +++ b/qlib/data/dataset/__init__.py @@ -27,7 +27,7 @@ class Dataset(Serializable): - setup data - The data related attributes' names should start with '_' so that it will not be saved on disk when serializing. - The data could specify the info to caculate the essential data for preparation + The data could specify the info to calculate the essential data for preparation """ self.setup_data(**kwargs) super().__init__() @@ -92,7 +92,7 @@ class DatasetH(Dataset): handler : Union[dict, DataHandler] handler could be: - - insntance of `DataHandler` + - instance of `DataHandler` - config of `DataHandler`. Please refer to `DataHandler` @@ -114,7 +114,6 @@ class DatasetH(Dataset): """ self.handler: DataHandler = init_instance_by_config(handler, accept_types=DataHandler) self.segments = segments.copy() - self.fetch_kwargs = {} super().__init__(**kwargs) def config(self, handler_kwargs: dict = None, **kwargs): @@ -124,7 +123,7 @@ class DatasetH(Dataset): Parameters ---------- handler_kwargs : dict - Config of DataHanlder, which could include the following arguments: + Config of DataHandler, which could include the following arguments: - arguments of DataHandler.conf_data, such as 'instruments', 'start_time' and 'end_time'. @@ -148,11 +147,11 @@ class DatasetH(Dataset): Parameters ---------- handler_kwargs : dict - init arguments of DataHanlder, which could include the following arguments: + init arguments of DataHandler, which could include the following arguments: - init_type : Init Type of Handler - - enable_cache : wheter to enable cache + - enable_cache : whether to enable cache """ super().setup_data(**kwargs) @@ -172,7 +171,7 @@ class DatasetH(Dataset): ---------- slc : slice """ - return self.handler.fetch(slc, **kwargs, **self.fetch_kwargs) + return self.handler.fetch(slc, **kwargs) def prepare( self, @@ -232,7 +231,7 @@ class TSDataSampler: (T)ime-(S)eries DataSampler This is the result of TSDatasetH - It works like `torch.data.utils.Dataset`, it provides a very convient interface for constructing time-series + It works like `torch.data.utils.Dataset`, it provides a very convenient interface for constructing time-series dataset based on tabular data. If user have further requirements for processing data, user could process them based on `TSDataSampler` or create @@ -289,29 +288,12 @@ class TSDataSampler: # the data type will be changed # The index of usable data is between start_idx and end_idx + self.start_idx, self.end_idx = self.data.index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_df, self.idx_map = self.build_index(self.data) - self.data_index = deepcopy(self.data.index) - - if flt_data is not None: - self.flt_data = np.array(flt_data).reshape(-1) - self.idx_map = self.flt_idx_map(self.flt_data, self.idx_map) - self.data_index = self.data_index[np.where(self.flt_data == True)[0]] - - self.start_idx, self.end_idx = self.data_index.slice_locs(start=pd.Timestamp(start), end=pd.Timestamp(end)) self.idx_arr = np.array(self.idx_df.values, dtype=np.float64) # for better performance - + self.data_idx = deepcopy(self.data.index) del self.data # save memory - @staticmethod - def flt_idx_map(flt_data, idx_map): - idx = 0 - new_idx_map = {} - for i, exist in enumerate(flt_data): - if exist: - new_idx_map[idx] = idx_map[i] - idx += 1 - return new_idx_map - def get_index(self): """ Get the pandas index of the data, it will be useful in following scenarios @@ -461,7 +443,7 @@ class TSDatasetH(DatasetH): (T)ime-(S)eries Dataset (H)andler - Covnert the tabular data to Time-Series data + Convert the tabular data to Time-Series data Requirements analysis @@ -505,19 +487,8 @@ class TSDatasetH(DatasetH): """ split the _prepare_raw_seg is to leave a hook for data preprocessing before creating processing data """ - dtype = kwargs.pop("dtype") + dtype = kwargs.pop("dtype", None) start, end = slc.start, slc.stop - flt_col = kwargs.pop("flt_col", None) - # TSDatasetH will retrieve more data for complete - data = self._prepare_raw_seg(slc, **kwargs) - - flt_kwargs = deepcopy(kwargs) - if flt_col is not None: - flt_kwargs["col_set"] = flt_col - flt_data = self._prepare_raw_seg(slc, **flt_kwargs) - assert len(flt_data.columns) == 1 - else: - flt_data = None - - tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype, flt_data=flt_data) + data = self._prepare_raw_seg(slc=slc, **kwargs) + tsds = TSDataSampler(data=data, start=start, end=end, step_len=self.step_len, dtype=dtype) return tsds diff --git a/qlib/data/dataset/handler.py b/qlib/data/dataset/handler.py index f1fa39c3b..63b49d78b 100644 --- a/qlib/data/dataset/handler.py +++ b/qlib/data/dataset/handler.py @@ -36,7 +36,7 @@ class DataHandler(Serializable): The data handler try to maintain a handler with 2 level. `datetime` & `instruments`. - Any order of the index level can be suported (The order will be implied in the data). + Any order of the index level can be supported (The order will be implied in the data). The order <`datetime`, `instruments`> will be used when the dataframe index name is missed. Example of the data: @@ -77,7 +77,7 @@ class DataHandler(Serializable): data_loader : Tuple[dict, str, DataLoader] data loader to load the data. init_data : - intialize the original data in the constructor. + initialize the original data in the constructor. fetch_orig : bool Return the original data instead of copy if possible. """ @@ -128,7 +128,7 @@ class DataHandler(Serializable): def setup_data(self, enable_cache: bool = False): """ - Set Up the data in case of running intialization for multiple time + Set Up the data in case of running initialization for multiple time It is responsible for maintaining following variable 1) self._data @@ -453,7 +453,7 @@ class DataHandlerLP(DataHandler): def setup_data(self, init_type: str = IT_FIT_SEQ, **kwargs): """ - Set up the data in case of running intialization for multiple time + Set up the data in case of running initialization for multiple time Parameters ---------- diff --git a/qlib/model/ens/ensemble.py b/qlib/model/ens/ensemble.py index 7ccf98ab2..1fb14a37b 100644 --- a/qlib/model/ens/ensemble.py +++ b/qlib/model/ens/ensemble.py @@ -5,6 +5,7 @@ Ensemble can merge the objects in an Ensemble. For example, if there are many submodels predictions, we may need to merge them in an ensemble predictions. """ +from typing import Union import pandas as pd @@ -24,6 +25,30 @@ class Ensemble: raise NotImplementedError(f"Please implement the `__call__` method.") +class SingleKeyEnsemble(Ensemble): + + """ + Extract the object if there is only one key and value in dict. Make result more readable. + {Only key: Only value} -> Only value + If there are more than 1 key or less than 1 key, then do nothing. + Even you can run this recursively to make dict more readable. + NOTE: Default run recursively. + """ + + def __call__(self, ensemble_dict: Union[dict, object], recursion: bool = True) -> object: + if not isinstance(ensemble_dict, dict): + return ensemble_dict + if recursion: + tmp_dict = {} + for k, v in ensemble_dict.items(): + tmp_dict[k] = self(v, recursion) + ensemble_dict = tmp_dict + keys = list(ensemble_dict.keys()) + if len(keys) == 1: + ensemble_dict = ensemble_dict[keys[0]] + return ensemble_dict + + class RollingEnsemble(Ensemble): """Merge the rolling objects in an Ensemble""" @@ -47,3 +72,24 @@ class RollingEnsemble(Ensemble): artifact = artifact[~artifact.index.duplicated(keep="last")] artifact = artifact.sort_index() return artifact + + +class AverageEnsemble(Ensemble): + def __call__(self, ensemble_dict: dict): + """ + Average a dict of same shape dataframe like `prediction` or `IC` into an ensemble. + + NOTE: The values of dict must be pd.DataFrame, and have the index "datetime" + + Args: + ensemble_dict (dict): a dict like {"A": pd.DataFrame, "B": pd.DataFrame}. + The key of the dict will be ignored. + + Returns: + pd.DataFrame: the complete result of averaging. + """ + values = list(ensemble_dict.values()) + results = pd.concat(values, axis=1) + results = results.mean(axis=1).to_frame("score") + results = results.sort_index() + return results diff --git a/qlib/model/ens/group.py b/qlib/model/ens/group.py index d53a55f4c..d8f174105 100644 --- a/qlib/model/ens/group.py +++ b/qlib/model/ens/group.py @@ -3,6 +3,13 @@ """ Group can group a set of object based on `group_func` and change them to a dict. +After group, we provide a method to reduce them. + +For example: + +group: {(A,B,C1): object, (A,B,C2): object} -> {(A,B): {C1: object, C2: object}} +reduce: {(A,B): {C1: object, C2: object}} -> {(A,B): object} + """ from qlib.model.ens.ensemble import Ensemble, RollingEnsemble diff --git a/qlib/model/trainer.py b/qlib/model/trainer.py index a0d252ab4..7680674a6 100644 --- a/qlib/model/trainer.py +++ b/qlib/model/trainer.py @@ -3,12 +3,12 @@ """ The Trainer will train a list of tasks and return a list of model recorder. -There are two steps in each Trainer including `train`(make model recorder) and `end_train`(modify model recorder). +There are two steps in each Trainer including ``train``(make model recorder) and ``end_train``(modify model recorder). -This is concept called "DelayTrainer", which can be used in online simulating to parallel training. -In "DelayTrainer", the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. +This is concept called ``DelayTrainer``, which can be used in online simulating to parallel training. +In ``DelayTrainer``, the first step is only to save some necessary info to model recorder, and the second step which will be finished in the end can do some concurrent and time-consuming operations such as model fitting. -`Qlib` offer two kind of Trainer, TrainerR is simplest and TrainerRM is based on TaskManager to help manager tasks lifecycle automatically. +``Qlib`` offer two kind of Trainer, ``TrainerR`` is the simplest way and ``TrainerRM`` is based on TaskManager to help manager tasks lifecycle automatically. """ import socket @@ -36,9 +36,6 @@ def begin_task_train(task_config: dict, experiment_name: str, recorder_name: str Returns: Recorder: the model recorder """ - # FIXME: recorder_id - if recorder_name is None: - recorder_name = str(time.time()) with R.start(experiment_name=experiment_name, recorder_name=recorder_name): R.log_params(**flatten_dict(task_config)) R.save_objects(**{"task": task_config}) # keep the original format and datatype @@ -58,7 +55,7 @@ def end_task_train(rec: Recorder, experiment_name: str) -> Recorder: Returns: Recorder: the model recorder """ - with R.start(experiment_name=experiment_name, recorder_name=rec.info["name"], resume=True): + with R.start(experiment_name=experiment_name, recorder_id=rec.info["id"], resume=True): task_config = R.load_object("task") # model & dataset initiation model: Model = init_instance_by_config(task_config["model"]) @@ -314,7 +311,8 @@ class TrainerRM(Trainer): def reset(self): """ - NOTE: this method will delete all task in this task_pool! + .. note:: + this method will delete all task in this task_pool! """ tm = TaskManager(task_pool=self.task_pool) tm.remove() diff --git a/qlib/workflow/online/manager.py b/qlib/workflow/online/manager.py index 4e9290096..6c62fbce9 100644 --- a/qlib/workflow/online/manager.py +++ b/qlib/workflow/online/manager.py @@ -2,11 +2,14 @@ # Licensed under the MIT License. """ -OnlineManager can manage a set of OnlineStrategy and run them dynamically. +OnlineManager can manage a set of `Online Strategy <#Online Strategy>`_ and run them dynamically. With the change of time, the decisive models will be also changed. In this module, we call those contributing models as `online` models. In every routine(such as everyday or every minutes), the `online` models maybe changed and the prediction of them need to be updated. So this module provide a series methods to control this process. + +This module also provide a method to simulate `Online Strategy <#Online Strategy>`_ in the history. +Which means you can verify your strategy or find a better one. """ from typing import Dict, List, Union @@ -14,12 +17,18 @@ from typing import Dict, List, Union import pandas as pd from qlib import get_module_logger from qlib.data.data import D +from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble from qlib.utils.serial import Serializable from qlib.workflow.online.strategy import OnlineStrategy from qlib.workflow.task.collect import HyperCollector class OnlineManager(Serializable): + """ + OnlineManager can manage online models with `Online Strategy <#Online Strategy>`_. + It also provide a history recording which models are onlined at what time. + """ + def __init__( self, strategy: Union[OnlineStrategy, List[OnlineStrategy]], @@ -29,10 +38,11 @@ class OnlineManager(Serializable): ): """ Init OnlineManager. + One OnlineManager must have at least one OnlineStrategy. Args: strategy (Union[OnlineStrategy, List[OnlineStrategy]]): an instance of OnlineStrategy or a list of OnlineStrategy - begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None. + begin_time (Union[str,pd.Timestamp], optional): the OnlineManager will begin at this time. Defaults to None for using latest date. freq (str, optional): data frequency. Defaults to "day". need_log (bool, optional): print log or not. Defaults to True. """ @@ -50,7 +60,7 @@ class OnlineManager(Serializable): def first_train(self): """ - Run every strategy first_train method and record the online history + Run every strategy first_train method and record the online history. """ for strategy in self.strategy: self.logger.info(f"Strategy `{strategy.name_id}` begins first training...") @@ -62,7 +72,7 @@ class OnlineManager(Serializable): Run typical update process for every strategy and record the online history. The typical update process after a routine, such as day by day or month by month. - update online prediction -> prepare signals -> prepare tasks -> prepare new models -> reset online models + The process is: Prepare signals -> Prepare tasks -> Prepare online models. Args: cur_time (Union[str,pd.Timestamp], optional): run routine method in this time. Defaults to None. @@ -84,15 +94,15 @@ class OnlineManager(Serializable): def get_collector(self) -> HyperCollector: """ - Get the instance of HyperCollector to collect results from every strategy. + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results from every strategy. Returns: - HyperCollector: the collector can collect other collectors. + HyperCollector: the collector to collect other collectors (using SingleKeyEnsemble() to make results more readable). """ collector_dict = {} for strategy in self.strategy: collector_dict[strategy.name_id] = strategy.get_collector() - return HyperCollector(collector_dict) + return HyperCollector(collector_dict, process_list=SingleKeyEnsemble()) def get_online_history(self, strategy_name_id: str) -> list: """ @@ -102,7 +112,7 @@ class OnlineManager(Serializable): strategy_name_id (str): the name_id of strategy Returns: - dict: a list like [(time, [online_models])] + list: a list like [(begin_time, [online_models])] """ history_dict = self.history[strategy_name_id] history = [] @@ -121,10 +131,27 @@ class OnlineManager(Serializable): for strategy in self.strategy: strategy.delay_prepare(self.get_online_history(strategy.name_id), **delay_kwargs) + def get_signals(self) -> pd.DataFrame: + """ + Average all strategy signals as the online signals. + + Assumption: the signals from every strategy is pd.DataFrame. Override this function to change. + + Returns: + pd.DataFrame: signals + """ + signals_dict = {} + for strategy in self.strategy: + signals_dict[strategy.name_id] = strategy.get_signals() + return AverageEnsemble()(signals_dict) + def simulate(self, end_time, frequency="day", task_kwargs={}, model_kwargs={}, delay_kwargs={}) -> HyperCollector: """ - Starting from cur time, this method will simulate every routine in OnlineManager. - NOTE: Considering the parallel training, the models and signals can be perpared after all routine simulating. + Starting from current time, this method will simulate every routine in OnlineManager until end time. + + Considering the parallel training, the models and signals can be perpared after all routine simulating. + + The delay training way can be ``DelayTrainer`` and the delay preparing signals way can be ``delay_prepare``. Returns: HyperCollector: the OnlineManager's collector @@ -140,7 +167,9 @@ class OnlineManager(Serializable): def reset(self): """ - NOTE: This method will reset all strategy! Be careful to use it. + This method will reset all strategy! + + **Be careful to use it.** """ self.cur_time = self.begin_time self.history = {} diff --git a/qlib/workflow/online/strategy.py b/qlib/workflow/online/strategy.py index 3782ee652..0cae11b7f 100644 --- a/qlib/workflow/online/strategy.py +++ b/qlib/workflow/online/strategy.py @@ -2,8 +2,7 @@ # Licensed under the MIT License. """ -OnlineStrategy is a set of strategy of online serving. -It is working with OnlineManager, responsing how the tasks are generated, the models are updated and signals are perpared. +OnlineStrategy is a set of strategy for online serving. """ from copy import deepcopy @@ -12,6 +11,7 @@ from typing import List, Tuple, Union import pandas as pd from qlib.data.data import D from qlib.log import get_module_logger +from qlib.model.ens.ensemble import AverageEnsemble, SingleKeyEnsemble from qlib.model.ens.group import RollingGroup from qlib.model.trainer import Trainer, TrainerR from qlib.workflow import R @@ -23,9 +23,14 @@ from qlib.workflow.task.utils import TimeAdjuster, list_recorders class OnlineStrategy: + """ + OnlineStrategy is working with `Online Manager <#Online Manager>`_, responsing how the tasks are generated, the models are updated and signals are perpared. + """ + def __init__(self, name_id: str, trainer: Trainer = None, need_log=True): """ Init OnlineStrategy. + This module **MUST** use `Trainer <../reference/api.html#Trainer>`_ to finishing model training. Args: name_id (str): a unique name or id @@ -43,6 +48,7 @@ class OnlineStrategy: After perparing the data of last routine (a box in box-plot) which means the end of the routine, we can prepare trading signals for next routine. NOTE: Given a set prediction, all signals before these prediction end time will be prepared well. + Args: delay: bool If this method was called by `delay_prepare` @@ -52,7 +58,7 @@ class OnlineStrategy: def prepare_tasks(self, *args, **kwargs): """ After the end of a routine, check whether we need to prepare and train some new tasks. - return the new tasks waiting for training. + Return the new tasks waiting for training. You can find last online models by OnlineTool.online_models. """ @@ -66,10 +72,6 @@ class OnlineStrategy: Args: tasks (list): a list of tasks. - tag (str): - `ONLINE_TAG` for first train or additional train - `NEXT_ONLINE_TAG` for reset online model when calling `reset_online_tag` - `OFFLINE_TAG` for train but offline those models check_func: the method to judge if a model can be online. The parameter is the model record and return True for online. None for online every models. @@ -95,7 +97,8 @@ class OnlineStrategy: def get_collector(self) -> Collector: """ - Get the instance of collector to collect results of online serving. + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results of online serving. + For example: 1) collect predictions in Recorder @@ -109,7 +112,8 @@ class OnlineStrategy: def delay_prepare(self, history: list, **kwargs): """ Prepare all models and signals if there are something waiting for prepare. - NOTE: Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way. + + Assumption: the predictions of online models need less than next begin_time, or this method will work in a wrong way. Args: history (list): an online models list likes [begin_time:[online models]]. @@ -120,6 +124,12 @@ class OnlineStrategy: self.tool.reset_online_tag(recs_list) self.prepare_signals(delay=True) + def get_signals(self): + """ + Get prepared signals. + """ + raise NotImplementedError(f"Please implement the `get_signals` method.") + def reset(self): """ Delete all things and set them to default status. This method is convenient to explore the strategy for online simulation. @@ -164,17 +174,20 @@ class RollingAverageStrategy(OnlineStrategy): self.rg = rolling_gen self.tool = OnlineToolR(self.exp_name) self.ta = TimeAdjuster() - self.signal_rec = None # the recorder to record signals + with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): + self.signal_rec = R.get_recorder() # the recorder to record signals + self.signal_rec.save_objects(**{"signals": None}) - def get_collector(self, rec_key_func=None, rec_filter_func=None): + def get_collector(self, process_list=[RollingGroup()], rec_key_func=None, rec_filter_func=None, artifacts_key=None): """ - Get the instance of collector to collect results. The returned collector must can distinguish results in different models. + Get the instance of `Collector <../advanced/task_management.html#Task Collecting>`_ to collect results. The returned collector must can distinguish results in different models. Assumption: the models can be distinguished based on model name and rolling test segments. If you do not want this assumption, please implement your own method or use another rec_key_func. Args: rec_key_func (Callable): a function to get the key of a recorder. If None, use recorder id. rec_filter_func (Callable, optional): filter the recorder by return True or False. Defaults to None. + artifacts_key (List[str], optional): the artifacts key you want to get. If None, get all artifacts. """ def rec_key(recorder): @@ -188,18 +201,13 @@ class RollingAverageStrategy(OnlineStrategy): artifacts_collector = RecorderCollector( experiment=self.exp_name, - process_list=RollingGroup(), + process_list=process_list, rec_key_func=rec_key_func, rec_filter_func=rec_filter_func, + artifacts_key=artifacts_key, ) - signals_collector = RecorderCollector( - experiment=self.signal_exp_name, - rec_key_func=lambda rec: rec.info["name"], - rec_filter_func=lambda rec: rec.info["name"] == self.exp_name, - artifacts_path={"signals": "signals"}, - ) - return HyperCollector({"artifacts": artifacts_collector, "signals": signals_collector}) + return artifacts_collector def first_train(self) -> List[Recorder]: """ @@ -252,7 +260,11 @@ class RollingAverageStrategy(OnlineStrategy): Average the predictions of online models and offer a trading signals every routine. The signals will be saved to `signal` file of a recorder named self.exp_name of a experiment using the name of `SIGNAL_EXP` Even if the latest signal already exists, the latest calculation result will be overwritten. - NOTE: Given a prediction of a certain time, all signals before this time will be prepared well. + + .. note:: + + Given a prediction of a certain time, all signals before this time will be prepared well. + Args: over_write (bool, optional): If True, the new signals will overwrite the file. If False, the new signals will append to the end of signals. Defaults to False. Returns: @@ -260,21 +272,17 @@ class RollingAverageStrategy(OnlineStrategy): """ if not delay: self.tool.update_online_pred() - if self.signal_rec is None: - with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): - self.signal_rec = R.get_recorder() - pred = [] - try: - old_signals = self.signal_rec.load_object("signals") - except OSError: - old_signals = None + # Get a collector to average online models predictions + online_collector = self.get_collector( + process_list=[AverageEnsemble()], + rec_filter_func=lambda x: True if self.tool.get_online_tag(x) == self.tool.ONLINE_TAG else False, + artifacts_key="pred", + ) + online_results = online_collector() + signals = online_results["pred"] - for rec in self.tool.online_models(): - pred.append(rec.load_object("pred.pkl")) - - signals: pd.DataFrame = pd.concat(pred, axis=1).mean(axis=1).to_frame("score") - signals = signals.sort_index() + old_signals = self.get_signals() if old_signals is not None and not over_write: old_max = old_signals.index.get_level_values("datetime").max() new_signals = signals.loc[old_max:] @@ -288,18 +296,15 @@ class RollingAverageStrategy(OnlineStrategy): self.signal_rec.save_objects(**{"signals": signals}) return signals - # def get_signals(self): - # """ - # get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) + def get_signals(self) -> object: + """ + Get signals from the recorder(named self.exp_name) of the experiment(named self.SIGNAL_EXP) - # Returns: - # signals - # """ - # if self.signal_rec is None: - # with R.start(experiment_name=self.signal_exp_name, recorder_name=self.exp_name, resume=True): - # self.signal_rec = R.get_recorder() - # signals = self.signal_rec.load_object("signals") - # return signals + Returns: + object: signals + """ + signals = self.signal_rec.load_object("signals") + return signals def _list_latest(self, rec_list: List[Recorder]): """ diff --git a/qlib/workflow/online/update.py b/qlib/workflow/online/update.py index 69ad55324..ab910ba8d 100644 --- a/qlib/workflow/online/update.py +++ b/qlib/workflow/online/update.py @@ -2,7 +2,7 @@ # Licensed under the MIT License. """ -Update is a module to update artifacts such as predictions, when the stock data updating. +Updater is a module to update artifacts such as predictions, when the stock data is updating. """ from abc import ABCMeta, abstractmethod @@ -89,9 +89,13 @@ class PredUpdater(RecordUpdater): hist_ref : int Sometimes, the dataset will have historical depends. Leave the problem to user to set the length of historical dependency - NOTE: the start_time is not included in the hist_ref - # TODO: automate this step in the future. + + .. note:: + + the start_time is not included in the hist_ref + """ + # TODO: automate this hist_ref in the future. super().__init__(record=record, need_log=need_log) self.to_date = to_date diff --git a/qlib/workflow/online/utils.py b/qlib/workflow/online/utils.py index 4d630a665..296ca3ea6 100644 --- a/qlib/workflow/online/utils.py +++ b/qlib/workflow/online/utils.py @@ -16,6 +16,9 @@ from qlib.workflow.task.utils import list_recorders class OnlineTool: + """ + OnlineTool. + """ ONLINE_KEY = "online_status" # the online status key in recorder ONLINE_TAG = "online" # the 'online' model diff --git a/qlib/workflow/task/collect.py b/qlib/workflow/task/collect.py index d74d08184..28320e2ce 100644 --- a/qlib/workflow/task/collect.py +++ b/qlib/workflow/task/collect.py @@ -5,6 +5,7 @@ Collector can collect object from everywhere and process them such as merging, grouping, averaging and so on. """ +from qlib.model.ens.ensemble import SingleKeyEnsemble from qlib.workflow import R import dill as pickle @@ -81,7 +82,7 @@ class Collector: filepath (str): the path of file Returns: - bool: if successed + bool: if succeeded """ try: with open(filepath, "wb") as f: @@ -122,6 +123,8 @@ class HyperCollector(Collector): Args: collector_dict (dict): the dict like {collector_key, Collector} process_list (list or Callable): the list of processors or the instance of processor to process dict. + NOTE: process_list = [SingleKeyEnsemble()] can ignore key and use value directly if there is only one {k,v} in a dict. + This can make result more readable. If you want to maintain as it should be, just give a empty process list. """ super().__init__(process_list=process_list) self.collector_dict = collector_dict diff --git a/qlib/workflow/task/manage.py b/qlib/workflow/task/manage.py index 3c3144fe8..c71be7d39 100644 --- a/qlib/workflow/task/manage.py +++ b/qlib/workflow/task/manage.py @@ -52,9 +52,13 @@ class TaskManager: Assumption: the data in MongoDB was encoded and the data out of MongoDB was decoded Here are four status which are: + STATUS_WAITING: waiting for train + STATUS_RUNNING: training - STATUS_PART_DONE: finished some step and waiting for next step. + + STATUS_PART_DONE: finished some step and waiting for next step + STATUS_DONE: all work done """ @@ -393,9 +397,13 @@ def run_task( While task pool is not empty (has WAITING tasks), use task_func to fetch and run tasks in task_pool After running this method, here are 4 situations (before_status -> after_status): + STATUS_WAITING -> STATUS_DONE: use task["def"] as `task_func` param + STATUS_WAITING -> STATUS_PART_DONE: use task["def"] as `task_func` param + STATUS_PART_DONE -> STATUS_PART_DONE: use task["res"] as `task_func` param + STATUS_PART_DONE -> STATUS_DONE: use task["res"] as `task_func` param Parameters