mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
Format
This commit is contained in:
@@ -8,9 +8,7 @@ import pandas as pd
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def calc_ic(
|
||||
pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False
|
||||
) -> Tuple[pd.Series, pd.Series]:
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]:
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
@@ -29,9 +27,7 @@ def calc_ic(
|
||||
"""
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(
|
||||
lambda df: df["pred"].corr(df["label"], method="spearman")
|
||||
)
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
|
||||
@@ -143,8 +143,8 @@ class QlibDataLoader(DLWParser):
|
||||
|
||||
def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame:
|
||||
if instruments is None:
|
||||
warnings.warn('`instruments` is not set, will load all stocks')
|
||||
instruments = 'all'
|
||||
warnings.warn("`instruments` is not set, will load all stocks")
|
||||
instruments = "all"
|
||||
if isinstance(instruments, str):
|
||||
instruments = D.instruments(instruments, filter_pipe=self.filter_pipe)
|
||||
elif self.filter_pipe is not None:
|
||||
@@ -161,7 +161,9 @@ class StaticDataLoader(DataLoader):
|
||||
DataLoader that supports loading data from file or as provided.
|
||||
"""
|
||||
|
||||
def __init__(self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None):
|
||||
def __init__(
|
||||
self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None
|
||||
):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
@@ -192,22 +194,18 @@ class StaticDataLoader(DataLoader):
|
||||
df = self._data.loc(axis=0)[:, instruments]
|
||||
if start_time is None and end_time is None:
|
||||
return df # NOTE: avoid copy by loc
|
||||
return df.loc[pd.Timestamp(start_time):pd.Timestamp(end_time)]
|
||||
return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)]
|
||||
|
||||
def _maybe_load_raw_data(self):
|
||||
if self._data is not None:
|
||||
return
|
||||
self._data = load_dataset(self._feature_path_or_obj)
|
||||
if self._label_path_or_obj is not None:
|
||||
self._data = pd.concat(
|
||||
{"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1
|
||||
)
|
||||
self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1)
|
||||
if not isinstance(self._data.columns, pd.MultiIndex):
|
||||
self._data.columns = pd.MultiIndex.from_arrays(
|
||||
[
|
||||
np.array(["feature", "label"])[
|
||||
self._data.columns.str.contains("^LABEL").astype(int)
|
||||
],
|
||||
np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)],
|
||||
self._data.columns,
|
||||
]
|
||||
)
|
||||
|
||||
@@ -702,10 +702,10 @@ def load_dataset(path_or_obj):
|
||||
if isinstance(path_or_obj, pd.DataFrame):
|
||||
return path_or_obj
|
||||
_, extension = os.path.splitext(path_or_obj)
|
||||
if extension == '.h5':
|
||||
if extension == ".h5":
|
||||
return pd.read_hdf(path_or_obj)
|
||||
elif extension == '.pkl':
|
||||
elif extension == ".pkl":
|
||||
return pd.read_pickle(path_or_obj)
|
||||
elif extension == '.csv':
|
||||
elif extension == ".csv":
|
||||
return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1])
|
||||
raise ValueError(f'unsupported file type `{extension}`')
|
||||
raise ValueError(f"unsupported file type `{extension}`")
|
||||
|
||||
@@ -166,22 +166,23 @@ class SigAnaRecord(SignalRecord):
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
}
|
||||
objects = {
|
||||
'ic.pkl': ic,
|
||||
'ric.pkl': ric
|
||||
}
|
||||
objects = {"ic.pkl": ic, "ric.pkl": ric}
|
||||
if self.ana_long_short:
|
||||
long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0])
|
||||
metrics.update({
|
||||
'Long-Short Ann Return': long_short_r.mean() * self.ann_scaler,
|
||||
'Long-Short Ann Sharpe': long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
|
||||
'Long-Avg Ann Return': long_avg_r.mean() * self.ann_scaler,
|
||||
'Long-Avg Ann Sharpe': long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
|
||||
})
|
||||
objects.update({
|
||||
'long_short_r.pkl': long_short_r,
|
||||
'long_avg_r.pkl': long_avg_r,
|
||||
})
|
||||
metrics.update(
|
||||
{
|
||||
"Long-Short Ann Return": long_short_r.mean() * self.ann_scaler,
|
||||
"Long-Short Ann Sharpe": long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5,
|
||||
"Long-Avg Ann Return": long_avg_r.mean() * self.ann_scaler,
|
||||
"Long-Avg Ann Sharpe": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5,
|
||||
}
|
||||
)
|
||||
objects.update(
|
||||
{
|
||||
"long_short_r.pkl": long_short_r,
|
||||
"long_avg_r.pkl": long_avg_r,
|
||||
}
|
||||
)
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**objects, artifact_path=self.get_path())
|
||||
pprint(metrics)
|
||||
@@ -189,7 +190,7 @@ class SigAnaRecord(SignalRecord):
|
||||
def list(self):
|
||||
paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")]
|
||||
if self.ana_long_short:
|
||||
paths.extend([self.get_path('long_short_r.pkl'), self.get_path('long_avg_r.pkl')])
|
||||
paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")])
|
||||
return paths
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user