mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
update format
This commit is contained in:
@@ -7,15 +7,18 @@ import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False) -> Tuple[pd.Series, pd.Series]:
|
||||
""" calculate the precision
|
||||
|
||||
def calc_prec(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", quantile: float = 0.2, dropna=False, is_alpha=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calculate the precision
|
||||
pred :
|
||||
pred
|
||||
label :
|
||||
label
|
||||
date_col :
|
||||
date_col
|
||||
|
||||
|
||||
Returns
|
||||
-------
|
||||
(pd.Series, pd.Series)
|
||||
@@ -23,29 +26,28 @@ def calc_prec(pred: pd.Series, label: pd.Series, date_col="datetime", quantile:
|
||||
"""
|
||||
if is_alpha:
|
||||
label = label - label.mean(level=0)
|
||||
if int(1/quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
if int(1 / quantile) >= len(label.index.get_level_values(1).unique()):
|
||||
raise ValueError("Need more instruments to calculate precision")
|
||||
|
||||
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
if dropna:
|
||||
df.dropna(inplace = True)
|
||||
|
||||
df.dropna(inplace=True)
|
||||
|
||||
group = df.groupby(level=date_col)
|
||||
|
||||
|
||||
N = lambda x: int(len(x) * quantile)
|
||||
# find the top/low quantile of prediction and treat them as long and short target
|
||||
long = group.apply(lambda x: x.nlargest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
short = group.apply(lambda x: x.nsmallest(N(x), columns="pred").label).reset_index(level=0, drop=True)
|
||||
|
||||
|
||||
groupll = long.groupby(date_col)
|
||||
ll_ration = groupll.apply(lambda x: x > 0)
|
||||
ll_c = groupll.count()
|
||||
|
||||
|
||||
groups = short.groupby(date_col)
|
||||
s_ration = groups.apply(lambda x: x < 0)
|
||||
s_c = groups.count()
|
||||
return (ll_ration.groupby(date_col).sum()/ll_c), (s_ration.groupby(date_col).sum()/s_c)
|
||||
return (ll_ration.groupby(date_col).sum() / ll_c), (s_ration.groupby(date_col).sum() / s_c)
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
|
||||
@@ -154,12 +154,13 @@ class SignalRecord(RecordTemp):
|
||||
|
||||
def load(self, name="pred.pkl"):
|
||||
return super().load(name)
|
||||
|
||||
|
||||
|
||||
|
||||
class HFSignalRecord(SignalRecord):
|
||||
"""
|
||||
This is the Signal Analysis Record class that generates the analysis results such as IC and IR. This class inherits the ``RecordTemp`` class.
|
||||
"""
|
||||
|
||||
artifact_path = "hg_sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
@@ -169,7 +170,7 @@ class HFSignalRecord(SignalRecord):
|
||||
pred = self.load("pred.pkl")
|
||||
raw_label = self.load("label.pkl")
|
||||
|
||||
long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha = True)
|
||||
long_pre, short_pre = calc_prec(pred.iloc[:, 0], raw_label.iloc[:, 0], is_alpha=True)
|
||||
ic, ric = calc_ic(pred.iloc[:, 0], raw_label.iloc[:, 0])
|
||||
metrics = {
|
||||
"IC": ic.mean(),
|
||||
@@ -177,7 +178,7 @@ class HFSignalRecord(SignalRecord):
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
"Long precision": long_pre.mean(),
|
||||
"Short precision": short_pre.mean()
|
||||
"Short precision": short_pre.mean(),
|
||||
}
|
||||
objects = {"ic.pkl": ic, "ric.pkl": ric}
|
||||
objects.update({"long_pre.pkl": long_pre, "short_pre.pkl": short_pre})
|
||||
@@ -199,7 +200,12 @@ class HFSignalRecord(SignalRecord):
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
|
||||
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl"), self.get_path("long_pre.pkl"), self.get_path("short_pre.pkl")]
|
||||
paths = [
|
||||
self.get_path("ic.pkl"),
|
||||
self.get_path("ric.pkl"),
|
||||
self.get_path("long_pre.pkl"),
|
||||
self.get_path("short_pre.pkl"),
|
||||
]
|
||||
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
|
||||
return paths
|
||||
|
||||
|
||||
Reference in New Issue
Block a user