diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
index 3970c8a0a..1e8442866 100755
--- a/qlib/data/dataset/processor.py
+++ b/qlib/data/dataset/processor.py
@@ -124,18 +124,19 @@ class ProcessInf(Processor):
 
 
 class Fillna(Processor):
-    """Process infinity """
+    """Process NaN"""
+
+    def __init__(self, fields_group=None, fill_value=0):
+        self.fields_group = fields_group
+        self.fill_value = fill_value
 
     def __call__(self, df):
-        def fill_na(df, columns=None, fill=0):
-
-            if columns == None:
-                columns = df.columns
-            df[columns] = df[columns].fillna(fill)
-
-            return df
-
-        return fill_na(df)
+        if self.fields_group is None:
+            df.fillna(self.fill_value, inplace=True)
+        else:
+            cols = get_group_columns(df, self.fields_group)
+            df.fillna({col: self.fill_value for col in cols}, inplace=True)
+        return df
 
 
 class MinMaxNorm(Processor):
@@ -203,3 +204,23 @@ class CSZScoreNorm(Processor):
         cols = get_group_columns(df, self.fields_group)
         df[cols] = df[cols].groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
         return df
+
+
+class CSRankNorm(Processor):
+    """Cross Sectional Rank Normalization
+
+    Within each "datetime" group, every selected column is replaced by its
+    percentile rank, shifted and rescaled to be roughly zero-mean and unit-std.
+    """
+
+    def __init__(self, fields_group=None):
+        self.fields_group = fields_group
+
+    def __call__(self, df):
+        # NOTE: the selected columns of `df` are overwritten in place
+        cols = get_group_columns(df, self.fields_group)
+        ranks = df[cols].groupby("datetime").rank(pct=True)
+        # percentile ranks are roughly uniform on (0, 1] (mean ~0.5, std ~1/sqrt(12)),
+        # so center at zero and scale by sqrt(12) ~= 3.46 towards unit std
+        df[cols] = (ranks - 0.5) * 3.46
+        return df