mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
black format
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
'''
|
||||
"""
|
||||
Here is a batch of evaluation functions.
|
||||
|
||||
The interface should be redesigned carefully in the future.
|
||||
'''
|
||||
"""
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col='datetime', dropna=False) -> (pd.Series, pd.Series):
|
||||
def calc_ic(pred: pd.Series, label: pd.Series, date_col="datetime", dropna=False) -> (pd.Series, pd.Series):
|
||||
"""calc_ic.
|
||||
|
||||
Parameters
|
||||
@@ -23,9 +23,9 @@ def calc_ic(pred: pd.Series, label: pd.Series, date_col='datetime', dropna=False
|
||||
(pd.Series, pd.Series)
|
||||
ic and rank ic
|
||||
"""
|
||||
df = pd.DataFrame({'pred': pred, 'label': label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df['pred'].corr(df['label']))
|
||||
ric = df.groupby(date_col).apply(lambda df: df['pred'].corr(df['label'], method='spearman'))
|
||||
df = pd.DataFrame({"pred": pred, "label": label})
|
||||
ic = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"]))
|
||||
ric = df.groupby(date_col).apply(lambda df: df["pred"].corr(df["label"], method="spearman"))
|
||||
if dropna:
|
||||
return ic.dropna(), ric.dropna()
|
||||
else:
|
||||
|
||||
@@ -100,11 +100,13 @@ class DatasetH(Dataset):
|
||||
self._handler = init_instance_by_config(handler, accept_types=DataHandler)
|
||||
self._segments = segments.copy()
|
||||
|
||||
def prepare(self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
def prepare(
|
||||
self,
|
||||
segments: Union[List[str], Tuple[str], str, slice],
|
||||
col_set=DataHandler.CS_ALL,
|
||||
data_key=DataHandlerLP.DK_I,
|
||||
**kwargs,
|
||||
) -> Union[List[pd.DataFrame], pd.DataFrame]:
|
||||
"""
|
||||
prepare the data for learning and inference
|
||||
|
||||
@@ -132,8 +134,8 @@ class DatasetH(Dataset):
|
||||
logger = get_module_logger("DatasetH")
|
||||
fetch_kwargs = {"col_set": col_set}
|
||||
fetch_kwargs.update(kwargs)
|
||||
if "data_key"in getfullargspec(self._handler.fetch).args:
|
||||
fetch_kwargs['data_key'] = data_key
|
||||
if "data_key" in getfullargspec(self._handler.fetch).args:
|
||||
fetch_kwargs["data_key"] = data_key
|
||||
else:
|
||||
logger.info(f"data_key[{data_key}] is ignored.")
|
||||
|
||||
|
||||
@@ -10,14 +10,14 @@ class Serializable:
|
||||
Serializable behaves like pickle.
|
||||
But it only saves the state whose name **does not** start with `_`
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._dump_all = False
|
||||
self._exclude = []
|
||||
|
||||
def __getstate__(self) -> dict:
|
||||
return {
|
||||
k: v
|
||||
for k, v in self.__dict__.items() if k not in self.exclude and (self.dump_all or not k.startswith("_"))
|
||||
k: v for k, v in self.__dict__.items() if k not in self.exclude and (self.dump_all or not k.startswith("_"))
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict):
|
||||
|
||||
@@ -251,7 +251,7 @@ class MLflowExpManager(ExpManager):
|
||||
self.active_experiment = None
|
||||
|
||||
def create_exp(self, experiment_name=None):
|
||||
assert(experiment_name is not None)
|
||||
assert experiment_name is not None
|
||||
# init experiment
|
||||
experiment_id = self.client.create_experiment(experiment_name)
|
||||
experiment = MLflowExperiment(experiment_id, experiment_name, self.uri)
|
||||
|
||||
@@ -119,7 +119,7 @@ class SignalRecord(RecordTemp):
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
except TypeError:
|
||||
# The argument number is not right
|
||||
del params['data_key']
|
||||
del params["data_key"]
|
||||
# The backend handler should be DataHandler
|
||||
raw_label = DatasetH.prepare(**params)
|
||||
self.recorder.save_objects(**{"label.pkl": raw_label})
|
||||
@@ -147,7 +147,7 @@ class SigAnaRecord(SignalRecord):
|
||||
"IC": ic.mean(),
|
||||
"ICIR": ic.mean() / ic.std(),
|
||||
"Rank IC": ric.mean(),
|
||||
"Rank ICIR": ric.mean() / ric.std()
|
||||
"Rank ICIR": ric.mean() / ric.std(),
|
||||
}
|
||||
self.recorder.log_metrics(**metrics)
|
||||
self.recorder.save_objects(**{"ic.pkl": ic, "ric.pkl": ric}, artifact_path=self.artifact_path_sig)
|
||||
|
||||
Reference in New Issue
Block a user