From ae757a4b5198a75c8650548acef52aac9a33f73e Mon Sep 17 00:00:00 2001 From: bxdd Date: Fri, 27 Nov 2020 13:09:40 +0800 Subject: [PATCH] black format --- qlib/contrib/data/handler.py | 2 +- qlib/contrib/model/pytorch_alstm.py | 1 - qlib/contrib/model/pytorch_nn.py | 3 ++- qlib/data/dataset/processor.py | 7 +++++-- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/qlib/contrib/data/handler.py b/qlib/contrib/data/handler.py index f74c2cebc..e97b00c24 100644 --- a/qlib/contrib/data/handler.py +++ b/qlib/contrib/data/handler.py @@ -226,7 +226,7 @@ class Alpha158(DataHandlerLP): data_loader=data_loader, infer_processors=infer_processors, learn_processors=learn_processors, - process_type=process_type + process_type=process_type, ) def get_feature_config(self): diff --git a/qlib/contrib/model/pytorch_alstm.py b/qlib/contrib/model/pytorch_alstm.py index 8f5ddc486..1b23d2401 100644 --- a/qlib/contrib/model/pytorch_alstm.py +++ b/qlib/contrib/model/pytorch_alstm.py @@ -146,7 +146,6 @@ class ALSTM(Model): raise ValueError("unknown metric `%s`" % self.metric) - def train_epoch(self, x_train, y_train): x_train_values = x_train.values diff --git a/qlib/contrib/model/pytorch_nn.py b/qlib/contrib/model/pytorch_nn.py index 47316ebf6..d324e27aa 100644 --- a/qlib/contrib/model/pytorch_nn.py +++ b/qlib/contrib/model/pytorch_nn.py @@ -22,6 +22,7 @@ from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, creat from ...log import get_module_logger, TimeInspector from ...workflow import R + class DNNModelPytorch(Model): """DNN Model @@ -349,7 +350,7 @@ class Net(nn.Module): def _weight_init(self): for m in self.modules(): if isinstance(m, nn.Linear): - nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_in', nonlinearity='leaky_relu') + nn.init.kaiming_normal_(m.weight, a=0.1, mode="fan_in", nonlinearity="leaky_relu") def forward(self, x): cur_output = x diff --git a/qlib/data/dataset/processor.py b/qlib/data/dataset/processor.py index fc85ccde9..76cf85c4a 100755 --- a/qlib/data/dataset/processor.py +++ b/qlib/data/dataset/processor.py @@ -100,7 +100,8 @@ class DropCol(Processor): else: mask = df.columns.isin(self.col_list) return df.loc[:, ~mask] - + + class TanhProcess(Processor): """ Use tanh to process noise data""" @@ -133,6 +134,7 @@ class ProcessInf(Processor): return replace_inf(df) + class Fillna(Processor): """Process NaN""" @@ -270,6 +272,7 @@ class CSRankNorm(Processor): df[cols] = t return df + class CSZFillna(Processor): """Cross Sectional Fill Nan""" @@ -279,4 +282,4 @@ class CSZFillna(Processor): def __call__(self, df): cols = get_group_columns(df, self.fields_group) df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean())) - return df \ No newline at end of file + return df