From b8d1e08010d03bc35a4979fff0962eac232e093d Mon Sep 17 00:00:00 2001 From: Christian Clauss Date: Tue, 14 Sep 2021 06:13:27 +0200 Subject: [PATCH] Fix undefined names in Python code (#599) * Update pytorch_tabnet.py $ `flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics` ``` ./qlib/qlib/contrib/model/pytorch_tabnet.py:567:38: F821 undefined name 'inp' self.independ.append(GLU(inp, out_dim, vbs=vbs)) ^ ./qlib/examples/model_rolling/task_manager_rolling.py:75:18: F821 undefined name 'task_train' run_task(task_train, self.task_pool, experiment_name=self.experiment_name) ^ 2 F821 undefined name 'task_train' 2 ``` * Fix undefined names in Python code * from qlib.model.trainer import task_train --- examples/model_rolling/task_manager_rolling.py | 2 +- qlib/contrib/model/pytorch_tabnet.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/model_rolling/task_manager_rolling.py b/examples/model_rolling/task_manager_rolling.py index 844f18198..091a87862 100644 --- a/examples/model_rolling/task_manager_rolling.py +++ b/examples/model_rolling/task_manager_rolling.py @@ -17,7 +17,7 @@ from qlib.workflow.task.gen import RollingGen, task_generator from qlib.workflow.task.manage import TaskManager, run_task from qlib.workflow.task.collect import RecorderCollector from qlib.model.ens.group import RollingGroup -from qlib.model.trainer import TrainerRM +from qlib.model.trainer import TrainerRM, task_train from qlib.tests.config import CSI100_RECORD_LGB_TASK_CONFIG, CSI100_RECORD_XGBOOST_TASK_CONFIG diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index b05d9a026..bd8f085ec 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -564,7 +564,7 @@ class FeatureTransformer(nn.Module): self.shared = None self.independ = nn.ModuleList() if first: - self.independ.append(GLU(inp, out_dim, vbs=vbs)) + self.independ.append(GLU(inp_dim, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) self.scale = float(np.sqrt(0.5))