diff --git a/qlib/contrib/eva/alpha.py b/qlib/contrib/eva/alpha.py index e00fbfe25..c68571853 100644 --- a/qlib/contrib/eva/alpha.py +++ b/qlib/contrib/eva/alpha.py @@ -8,9 +8,7 @@ import pandas as pd from typing import Tuple -def calc_ic( - pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False -) -> Tuple[pd.Series, pd.Series]: +def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> Tuple[pd.Series, pd.Series]: """calc_ic. Parameters @@ -29,9 +27,7 @@ def calc_ic( """ df = pd.DataFrame({"pred": pred, "label": label}) ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"])) - ric = df.groupby(date_col).apply( - lambda df: df["pred"].corr(df["label"], method="spearman") - ) + ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman")) if dropna: return ic.dropna(), ric.dropna() else: diff --git a/qlib/data/dataset/loader.py b/qlib/data/dataset/loader.py index 7e8dd507c..eddbca044 100644 --- a/qlib/data/dataset/loader.py +++ b/qlib/data/dataset/loader.py @@ -143,8 +143,8 @@ class QlibDataLoader(DLWParser): def load_group_df(self, instruments, exprs: list, names: list, start_time=None, end_time=None) -> pd.DataFrame: if instruments is None: - warnings.warn('`instruments` is not set, will load all stocks') - instruments = 'all' + warnings.warn("`instruments` is not set, will load all stocks") + instruments = "all" if isinstance(instruments, str): instruments = D.instruments(instruments, filter_pipe=self.filter_pipe) elif self.filter_pipe is not None: @@ -161,7 +161,9 @@ class StaticDataLoader(DataLoader): DataLoader that supports loading data from file or as provided. """ - def __init__(self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None): + def __init__( + self, feature_path_or_obj: Union[str, pd.DataFrame], label_path_or_obj: Union[str, pd.DataFrame] = None + ): """ Parameters ---------- @@ -192,22 +194,18 @@ class StaticDataLoader(DataLoader): df = self._data.loc(axis=0)[:, instruments] if start_time is None and end_time is None: return df # NOTE: avoid copy by loc - return df.loc[pd.Timestamp(start_time):pd.Timestamp(end_time)] + return df.loc[pd.Timestamp(start_time) : pd.Timestamp(end_time)] def _maybe_load_raw_data(self): if self._data is not None: return self._data = load_dataset(self._feature_path_or_obj) if self._label_path_or_obj is not None: - self._data = pd.concat( - {"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1 - ) + self._data = pd.concat({"feature": self._data, "label": load_dataset(self._label_path_or_obj)}, axis=1) if not isinstance(self._data.columns, pd.MultiIndex): self._data.columns = pd.MultiIndex.from_arrays( [ - np.array(["feature", "label"])[ - self._data.columns.str.contains("^LABEL").astype(int) - ], + np.array(["feature", "label"])[self._data.columns.str.contains("^LABEL").astype(int)], self._data.columns, ] ) diff --git a/qlib/utils/__init__.py b/qlib/utils/__init__.py index c77c67fa2..8a1436799 100644 --- a/qlib/utils/__init__.py +++ b/qlib/utils/__init__.py @@ -702,10 +702,10 @@ def load_dataset(path_or_obj): if isinstance(path_or_obj, pd.DataFrame): return path_or_obj _, extension = os.path.splitext(path_or_obj) - if extension == '.h5': + if extension == ".h5": return pd.read_hdf(path_or_obj) - elif extension == '.pkl': + elif extension == ".pkl": return pd.read_pickle(path_or_obj) - elif extension == '.csv': + elif extension == ".csv": return pd.read_csv(path_or_obj, parse_dates=True, index_col=[0, 1]) - raise ValueError(f'unsupported file type `{extension}`') + raise ValueError(f"unsupported file type `{extension}`") diff --git a/qlib/workflow/record_temp.py b/qlib/workflow/record_temp.py index ffb339278..81b0022c5 100644 --- a/qlib/workflow/record_temp.py +++ b/qlib/workflow/record_temp.py @@ -166,22 +166,23 @@ class SigAnaRecord(SignalRecord): "Rank IC": ric.mean(), "Rank ICIR": ric.mean() / ric.std(), } - objects = { - 'ic.pkl': ic, - 'ric.pkl': ric - } + objects = {"ic.pkl": ic, "ric.pkl": ric} if self.ana_long_short: long_short_r, long_avg_r = calc_long_short_return(pred.iloc[:, 0], label.iloc[:, 0]) - metrics.update({ - 'Long-Short Ann Return': long_short_r.mean() * self.ann_scaler, - 'Long-Short Ann Sharpe': long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5, - 'Long-Avg Ann Return': long_avg_r.mean() * self.ann_scaler, - 'Long-Avg Ann Sharpe': long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5, - }) - objects.update({ - 'long_short_r.pkl': long_short_r, - 'long_avg_r.pkl': long_avg_r, - }) + metrics.update( + { + "Long-Short Ann Return": long_short_r.mean() * self.ann_scaler, + "Long-Short Ann Sharpe": long_short_r.mean() / long_short_r.std() * self.ann_scaler ** 0.5, + "Long-Avg Ann Return": long_avg_r.mean() * self.ann_scaler, + "Long-Avg Ann Sharpe": long_avg_r.mean() / long_avg_r.std() * self.ann_scaler ** 0.5, + } + ) + objects.update( + { + "long_short_r.pkl": long_short_r, + "long_avg_r.pkl": long_avg_r, + } + ) self.recorder.log_metrics(**metrics) self.recorder.save_objects(**objects, artifact_path=self.get_path()) pprint(metrics) @@ -189,7 +190,7 @@ class SigAnaRecord(SignalRecord): def list(self): paths = [self.get_path("ic.pkl"), self.get_path("ric.pkl")] if self.ana_long_short: - paths.extend([self.get_path('long_short_r.pkl'), self.get_path('long_avg_r.pkl')]) + paths.extend([self.get_path("long_short_r.pkl"), self.get_path("long_avg_r.pkl")]) return paths