mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 19:41:00 +08:00
black format
This commit is contained in:
@@ -226,7 +226,7 @@ class Alpha158(DataHandlerLP):
|
||||
data_loader=data_loader,
|
||||
infer_processors=infer_processors,
|
||||
learn_processors=learn_processors,
|
||||
process_type=process_type
|
||||
process_type=process_type,
|
||||
)
|
||||
|
||||
def get_feature_config(self):
|
||||
|
||||
@@ -146,7 +146,6 @@ class ALSTM(Model):
|
||||
|
||||
raise ValueError("unknown metric `%s`" % self.metric)
|
||||
|
||||
|
||||
def train_epoch(self, x_train, y_train):
|
||||
|
||||
x_train_values = x_train.values
|
||||
|
||||
@@ -22,6 +22,7 @@ from ...utils import unpack_archive_with_buffer, save_multiple_parts_file, creat
|
||||
from ...log import get_module_logger, TimeInspector
|
||||
from ...workflow import R
|
||||
|
||||
|
||||
class DNNModelPytorch(Model):
|
||||
"""DNN Model
|
||||
|
||||
@@ -349,7 +350,7 @@ class Net(nn.Module):
|
||||
def _weight_init(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, a=0.1, mode='fan_in', nonlinearity='leaky_relu')
|
||||
nn.init.kaiming_normal_(m.weight, a=0.1, mode="fan_in", nonlinearity="leaky_relu")
|
||||
|
||||
def forward(self, x):
|
||||
cur_output = x
|
||||
|
||||
@@ -100,7 +100,8 @@ class DropCol(Processor):
|
||||
else:
|
||||
mask = df.columns.isin(self.col_list)
|
||||
return df.loc[:, ~mask]
|
||||
|
||||
|
||||
|
||||
class TanhProcess(Processor):
|
||||
""" Use tanh to process noise data"""
|
||||
|
||||
@@ -133,6 +134,7 @@ class ProcessInf(Processor):
|
||||
|
||||
return replace_inf(df)
|
||||
|
||||
|
||||
class Fillna(Processor):
|
||||
"""Process NaN"""
|
||||
|
||||
@@ -270,6 +272,7 @@ class CSRankNorm(Processor):
|
||||
df[cols] = t
|
||||
return df
|
||||
|
||||
|
||||
class CSZFillna(Processor):
|
||||
"""Cross Sectional Fill Nan"""
|
||||
|
||||
@@ -279,4 +282,4 @@ class CSZFillna(Processor):
|
||||
def __call__(self, df):
|
||||
cols = get_group_columns(df, self.fields_group)
|
||||
df[cols] = df[cols].groupby("datetime").apply(lambda x: x.fillna(x.mean()))
|
||||
return df
|
||||
return df
|
||||
|
||||
Reference in New Issue
Block a user