1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-06-06 05:51:17 +08:00
This commit is contained in:
Jactus
2020-11-24 16:42:32 +08:00
parent acae9087e9
commit 93ce9a4cb2
4 changed files with 30 additions and 35 deletions

View File

@@ -8,9 +8,7 @@ import pandas as pd
from typing import Tuple
def calc_ic(
pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
) -> Tuple[pd.Series, pd.Series]:
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
"""calc_ic.
Parameters
@@ -29,9 +27,7 @@ def calc_ic(
"""
df = pd.DataFrame({"pred": pred, "label": label})
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
ric = df.groupby(date_col).apply(
lambda df: df["pred"].corr(df["label"], method="spearman")
)
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
if dropna:
return ic.dropna(), ric.dropna()
else:

View File

@@ -143,8 +143,8 @@ class QlibDataLoader(DLWParser):
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
if instruments is None:
warnings.warn('`instruments` is not set, will load all stocks')
instruments = 'all'
warnings.warn("`instruments` is not set, will load all stocks")
instruments = "all"
if isinstance(instruments, str):
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
elif self.filter_pipe is not None:
@@ -161,7 +161,9 @@ class StaticDataLoader(DataLoader):
DataLoader that supports loading data from file or as provided.
"""
def __init__(self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None):
def __init__(
self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None
):
"""
Parameters
----------
@@ -192,22 +194,18 @@ class StaticDataLoader(DataLoader):
df = self._data.loc(axis=0)[:, instruments]
if start_time is None and end_time is None:
return df # NOTE: avoid copy by loc
return df.loc[pd.Timestamp(start_time):pd.Timestamp(end_time)]
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
def _maybe_load_raw_data(self):
if self._data is not None:
return
self._data = load_dataset(self._feature_path_or_obj)
if self._label_path_or_obj is not None:
self._data = pd.concat(
{"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1
)
self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1)
if not isinstance(self._data.columns, pd.MultiIndex):
self._data.columns = pd.MultiIndex.from_arrays(
[
np.array(["feature", "label"])[
self._data.columns.str.contains("^LABEL").astype(int)
],
np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)],
self._data.columns,
]
)

View File

@@ -702,10 +702,10 @@ def load_dataset(path_or_obj):
if isinstance(path_or_obj, pd.DataFrame):
return path_or_obj
_, extension = os.path.splitext(path_or_obj)
if extension == '.h5':
if extension == ".h5":
return pd.read_hdf(path_or_obj)
elif extension == '.pkl':
elif extension == ".pkl":
return pd.read_pickle(path_or_obj)
elif extension == '.csv':
elif extension == ".csv":
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
raise ValueError(f'unsupported file type `{extension}`')
raise ValueError(f"unsupported file type `{extension}`")

View File

@@ -166,22 +166,23 @@ class SigAnaRecord(SignalRecord):
"Rank IC": ric.mean(),
"Rank ICIR": ric.mean() / ric.std(),
}
objects = {
'ic.pkl': ic,
'ric.pkl': ric
}
objects = {"ic.pkl": ic, "ric.pkl": ric}
if self.ana_long_short:
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
metrics.update({
'Long-Short Ann Return': long_short_r.mean() * self.ann_scaler,
'Long-Short Ann Sharpe': long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
'Long-Avg Ann Return': long_avg_r.mean() * self.ann_scaler,
'Long-Avg Ann Sharpe': long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
})
objects.update({
'long_short_r.pkl': long_short_r,
'long_avg_r.pkl': long_avg_r,
})
metrics.update(
{
"Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,
"Long-Short Ann Sharpe": long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
"Long-Avg Ann Return": long_avg_r.mean() * self.ann_scaler,
"Long-Avg Ann Sharpe": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
}
)
objects.update(
{
"long_short_r.pkl": long_short_r,
"long_avg_r.pkl": long_avg_r,
}
)
self.recorder.log_metrics(**metrics)
self.recorder.save_objects(**objects, artifact_path=self.get_path())
pprint(metrics)
@@ -189,7 +190,7 @@ class SigAnaRecord(SignalRecord):
def list(self):
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
if self.ana_long_short:
paths.extend([self.get_path('long_short_r.pkl'), self.get_path('long_avg_r.pkl')])
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
return paths