From 3bf6c7f95f5cc77d4025358e618d5f688138f5cc Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Mon, 22 Mar 2021 15:37:54 +0800 Subject: [PATCH] update format --- qlib/contrib/eva/alpha.py | 24 +++++++++++++----------- qlib/workflow/record_temp.py | 16 +++++++++++----- 2 files changed, 24 insertions(+), 16 deletions(-) diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index e2beafc13..8078dd4ed 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -7,15 +7,18 @@ import pandas as pd from typing import Tuple -def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]: - """ calculate the precision + +def calc_prec( + pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False +) -> Tuple[pd.Series, pd.Series]: + """calculate the precision pred : pred label : label date_col : date_col - + Returns ------- (pd.Series, pd.Series) @@ -23,29 +26,28 @@ def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: """ if is_alpha: label = label - label.mean(level=0) - if int(1/quantile) >= len(label.index.get_level_values(1).unique()): + if int(1 / quantile) >= len(label.index.get_level_values(1).unique()): raise ValueError("Need more instruments to calculate precision") - df = pd.DataFrame({"pred": pred, "label": label}) if dropna: - df.dropna(inplace = True) - + df.dropna(inplace=True) + group = df.groupby(level=date_col) - + N = lambda x: int(len(x) * quantile) # find the top/low quantile of prediction and treat them as long and short target long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True) short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True) - + groupll = long.groupby(date_col) ll_ration = groupll.apply(lambda x: x > 0) ll_c = groupll.count() - + groups = short.groupby(date_col) s_ration = groups.apply(lambda x: x < 0) s_c = groups.count() - return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c) + return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c) def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index 8ab8405a5..c47b999f3 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -154,12 +154,13 @@ class SignalRecord(RecordTemp): def load(self, name="pred.pkl"): return super().load(name) - - + + class HFSignalRecord(SignalRecord): """ This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class. """ + artifact_path = "hg_sig_analysis" def __init__(self, recorder, **kwargs): @@ -169,7 +170,7 @@ class HFSignalRecord(SignalRecord): pred = self.load("pred.pkl") raw_label = self.load("label.pkl") - long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True) + long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True) ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0]) metrics = { "IC": ic.mean(), @@ -177,7 +178,7 @@ class HFSignalRecord(SignalRecord): "Rank IC": ric.mean(), "Rank ICIR": ric.mean() / ric.std(), "Long precision": long_pre.mean(), - "Short precision": short_pre.mean() + "Short precision": short_pre.mean(), } objects = {"ic.pkl": ic, "ric.pkl": ric} objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre}) @@ -199,7 +200,12 @@ class HFSignalRecord(SignalRecord): pprint(metrics) def list(self): - paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")] + paths = [ + self.get_path("ic.pkl"), + self.get_path("ric.pkl"), + self.get_path("long_pre.pkl"), + self.get_path("short_pre.pkl"), + ] paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths