diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 844f18198..091a87862 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b05d9a026..bd8f085ec 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module): self.shared = None self.independ = nn.ModuleList() if first: - self.independ.append(GLU(inp, out_dim, vbs=vbs)) + self.independ.append(GLU(inp_dim, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) self.scale = float(np.sqrt(0.5))