diff --git a/qlib/model/base.py b/qlib/model/base.py index 5a295787f..a7001f0a6 100644 --- a/qlib/model/base.py +++ b/qlib/model/base.py @@ -43,8 +43,8 @@ class Model(BaseModel): # get weights try: - wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], data_key=DataHandlerLP.DK_L) - w_train, w_valid = wdf_train["weight"], wdf_valid["weight"] + wdf_train, wdf_valid = dataset.prepare(["train", "valid"], col_set=["weight"], + data_key=DataHandlerLP.DK_L, w_train, w_valid = wdf_train["weight"], wdf_valid["weight"] except KeyError as e: w_train = pd.DataFrame(np.ones_like(y_train.values), index=y_train.index) w_valid = pd.DataFrame(np.ones_like(y_valid.values), index=y_valid.index)