mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-03 02:50:58 +08:00
support long-short analysis
This commit is contained in:
@@ -5,8 +5,12 @@ The interface should be redesigned carefully in the future.
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
|
||||
def calc_ic(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
@@ -25,8 +29,52 @@ def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
ric = df.groupby(date_col).apply(
|
||||
lambda df: df["pred"].corr(df["label"], method="spearman")
|
||||
)
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
return ic, ric
|
||||
|
||||
|
||||
def calc_long_short_return(
    pred: pd.Series,
    label: pd.Series,
    date_col: str = "datetime",
    quantile: float = 0.2,
    dropna: bool = False,
) -> Tuple[pd.Series, pd.Series]:
    """
    Calculate per-date long-short and average returns.

    For every date, the rows are ranked by `pred`. The top
    ``int(len(day) * quantile)`` rows form the long leg and the bottom
    ``int(len(day) * quantile)`` rows form the short leg.

    Note:
        `label` must be raw stock returns.

        If ``int(len(day) * quantile)`` is 0 for a date (too few rows), both
        legs are empty and the long-short value for that date is NaN.

    Parameters
    ----------
    pred : pd.Series
        stock predictions; its index must contain a level named `date_col`
    label : pd.Series
        stock returns, indexed like `pred`
    date_col : str
        datetime index level name
    quantile : float
        long-short quantile (fraction of rows used for each leg)
    dropna : bool
        if True, rows with a missing `pred` or `label` are dropped before grouping

    Returns
    -------
    long_short_r : pd.Series
        per-date ``(mean long-leg label - mean short-leg label) / 2``
    long_avg_r : pd.Series
        per-date mean of `label` over all rows of that date
    """
    df = pd.DataFrame({"pred": pred, "label": label})
    if dropna:
        df.dropna(inplace=True)
    group = df.groupby(level=date_col)

    def leg_size(day: pd.DataFrame) -> int:
        # Truncate (not round) so each leg never exceeds the requested quantile.
        return int(len(day) * quantile)

    r_long = group.apply(lambda day: day.nlargest(leg_size(day), columns="pred").label.mean())
    r_short = group.apply(lambda day: day.nsmallest(leg_size(day), columns="pred").label.mean())
    r_avg = group.label.mean()
    # The spread is halved before being returned.
    return (r_long - r_short) / 2, r_avg
|
||||
|
||||
@@ -14,7 +14,7 @@ from ..data.dataset.handler import DataHandlerLP
|
||||
from ..utils import init_instance_by_config, get_module_by_module_path
|
||||
from ..log import get_module_logger
|
||||
from ..utils import flatten_dict
|
||||
from ..contrib.eva.alpha import calc_ic
|
||||
from ..contrib.eva.alpha import calc_ic, calc_long_short_return
|
||||
|
||||
logger = get_module_logger("workflow", "INFO")
|
||||
|
||||
@@ -148,7 +148,9 @@ class SigAnaRecord(SignalRecord):
|
||||
|
||||
artifact_path = "sig_analysis"
|
||||
|
||||
def __init__(self, recorder, **kwargs):
|
||||
def __init__(self, recorder, ana_long_short=False, ann_scaler=252, **kwargs):
|
||||
self.ana_long_short = ana_long_short
|
||||
self.ann_scaler = ann_scaler
|
||||
super().__init__(recorder=recorder, **kwargs)
|
||||
# The name must be unique. Otherwise it will be overridden
|
||||
|
||||
@@ -164,12 +166,31 @@ class SigAnaRecord(SignalRecord):
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
}
|
||||
objects = {
|
||||
'ic.pkl': ic,
|
||||
'ric.pkl': ric
|
||||
}
|
||||
if self.ana_long_short:
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics.update({
|
||||
'Long-Short Ann Return': long_short_r.mean() * self.ann_scaler,
|
||||
'Long-Short Ann Sharpe': long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
|
||||
'Long-Avg Ann Return': long_avg_r.mean() * self.ann_scaler,
|
||||
'Long-Avg Ann Sharpe': long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
|
||||
})
|
||||
objects.update({
|
||||
'long_short_r.pkl': long_short_r,
|
||||
'long_avg_r.pkl': long_avg_r,
|
||||
})
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.get_path())
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
|
||||
def list(self):
    """Return the artifact paths this record produces.

    The IC and Rank IC pickles are always listed. When `self.ana_long_short`
    is truthy, the long-short and long-average return pickles are listed too.
    """
    paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
    if self.ana_long_short:
        paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
    return paths
|
||||
|
||||
|
||||
class PortAnaRecord(SignalRecord):
|
||||
|
||||
Reference in New Issue
Block a user