mirror of
https://github.com/microsoft/qlib.git
synced 2026-06-06 05:51:17 +08:00
fixed a problem with multi index caused by the default value of groupkey (#1917)
* fixed a problem with multi index caused by the default value of groupkey * modify group_key default value * limit pandas verion * format with black * fix docs error * fix docs error * fixed bugs caused by pandas upgrade * remove needless code * reformat with black * limit version & add docs
This commit is contained in:
@@ -599,7 +599,7 @@ class TemporalFusionTransformer:
|
||||
print("Getting valid sampling locations.")
|
||||
valid_sampling_locations = []
|
||||
split_data_map = {}
|
||||
for identifier, df in data.groupby(id_col):
|
||||
for identifier, df in data.groupby(id_col, group_key=False):
|
||||
print("Getting locations for {}".format(identifier))
|
||||
num_entries = len(df)
|
||||
if num_entries >= self.time_steps:
|
||||
@@ -678,7 +678,7 @@ class TemporalFusionTransformer:
|
||||
input_cols = [tup[0] for tup in self.column_definition if tup[2] not in {InputTypes.ID, InputTypes.TIME}]
|
||||
|
||||
data_map = {}
|
||||
for _, sliced in data.groupby(id_col):
|
||||
for _, sliced in data.groupby(id_col, group_keys=False):
|
||||
col_mappings = {"identifier": [id_col], "time": [time_col], "outputs": [target_col], "inputs": input_cols}
|
||||
|
||||
for k in col_mappings:
|
||||
|
||||
@@ -78,13 +78,15 @@ DATASET_SETTING = {
|
||||
|
||||
|
||||
def get_shifted_label(data_df, shifts=5, col_shift="LABEL0"):
|
||||
return data_df[[col_shift]].groupby("instrument").apply(lambda df: df.shift(shifts))
|
||||
return data_df[[col_shift]].groupby("instrument", group_keys=False).apply(lambda df: df.shift(shifts))
|
||||
|
||||
|
||||
def fill_test_na(test_df):
|
||||
test_df_res = test_df.copy()
|
||||
feature_cols = ~test_df_res.columns.str.contains("label", case=False)
|
||||
test_feature_fna = test_df_res.loc[:, feature_cols].groupby("datetime").apply(lambda df: df.fillna(df.mean()))
|
||||
test_feature_fna = (
|
||||
test_df_res.loc[:, feature_cols].groupby("datetime", group_keys=False).apply(lambda df: df.fillna(df.mean()))
|
||||
)
|
||||
test_df_res.loc[:, feature_cols] = test_feature_fna
|
||||
return test_df_res
|
||||
|
||||
|
||||
Reference in New Issue
Block a user