From 7f9f54faf40d3d270876220d2caadbb1dc4bab20 Mon Sep 17 00:00:00 2001
From: Dong Zhou
Date: Tue, 24 Nov 2020 10:09:27 +0800
Subject: [PATCH] add CSRankNorm processor

---
Review note: CSRankNorm.__call__ now ranks within each datetime group; the
hunk as originally submitted repeated the CSZScoreNorm z-score body. The
second hunk header and the diffstat were updated to match. The post-image
blob id on the "index" line was not recomputed.

 qlib/data/dataset/processor.py | 38 ++++++++++++++++++++++++++++----------
 1 file changed, 28 insertions(+), 10 deletions(-)

diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py
index 3970c8a0a..1e8442866 100755
--- a/qlib/data/dataset/processor.py
+++ b/qlib/data/dataset/processor.py
@@ -124,18 +124,19 @@ class ProcessInf(Processor):
 
 
 class Fillna(Processor):
-    """Process infinity """
+    """Process NaN"""
+
+    def __init__(self, fields_group=None, fill_value=0):
+        self.fields_group = fields_group
+        self.fill_value = fill_value
 
     def __call__(self, df):
-        def fill_na(df, columns=None, fill=0):
-
-            if columns == None:
-                columns = df.columns
-            df[columns] = df[columns].fillna(fill)
-
-            return df
-
-        return fill_na(df)
+        if self.fields_group is None:
+            df.fillna(self.fill_value, inplace=True)
+        else:
+            cols = get_group_columns(df, self.fields_group)
+            df.fillna({col: self.fill_value for col in cols}, inplace=True)
+        return df
 
 
 class MinMaxNorm(Processor):
@@ -203,3 +204,20 @@ class CSZScoreNorm(Processor):
         cols = get_group_columns(df, self.fields_group)
         df[cols] = df[cols].groupby("datetime").apply(lambda df: (df - df.mean()).div(df.std()))
         return df
+
+
+class CSRankNorm(Processor):
+    """Cross Sectional Rank Normalization"""
+
+    def __init__(self, fields_group=None):
+        self.fields_group = fields_group
+
+    def __call__(self, df):
+        # try not modify original dataframe
+        cols = get_group_columns(df, self.fields_group)
+        # percentile rank of each value within its datetime group, in (0, 1]
+        t = df[cols].groupby("datetime").rank(pct=True)
+        t -= 0.5  # center the ranks around zero
+        t *= 3.46  # ~sqrt(12): a uniform unit interval has std 1/sqrt(12)
+        df[cols] = t
+        return df