diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index c1dce9308..6fcabfd21 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -93,12 +93,8 @@ class TabnetModel(Model): np.random.seed(self.seed) torch.manual_seed(self.seed) - self.tabnet_model = TabNet( - inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device - ).to(self.device) - self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to( - self.device - ) + self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device) + self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device) self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder)) self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder]))) @@ -401,7 +397,7 @@ class FinetuneModel(nn.Module): """ def __init__(self, input_dim, output_dim, trained_model): - super().__init__() + super(FinetuneModel, self).__init__() self.model = trained_model self.fc = nn.Linear(input_dim, output_dim) @@ -410,9 +406,9 @@ class FinetuneModel(nn.Module): class DecoderStep(nn.Module): - def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device): - super().__init__() - self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device) + def __init__(self, inp_dim, out_dim, shared, n_ind, vbs): + super(DecoderStep, self).__init__() + self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs) self.fc = nn.Linear(out_dim, out_dim) def forward(self, x): @@ -421,13 +417,13 @@ class DecoderStep(nn.Module): class TabNet_Decoder(nn.Module): - def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device): + def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps): """ TabNet decoder that is used in pre-training """ self.out_dim = out_dim - super().__init__() + super(TabNet_Decoder, self).__init__() if n_shared > 0: self.shared = nn.ModuleList() self.shared.append(nn.Linear(inp_dim, 2 * out_dim)) @@ -438,7 +434,7 @@ class TabNet_Decoder(nn.Module): self.n_steps = n_steps self.steps = nn.ModuleList() for x in range(n_steps): - self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device)) + self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs)) def forward(self, x): out = torch.zeros(x.size(0), self.out_dim).to(x.device) @@ -448,9 +444,7 @@ class TabNet_Decoder(nn.Module): class TabNet(nn.Module): - def __init__( - self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu" - ): + def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024): """ TabNet AKA the original encoder @@ -463,7 +457,7 @@ class TabNet(nn.Module): relax coefficient: virtual batch size: """ - super().__init__() + super(TabNet, self).__init__() # set the number of shared step in feature transformer if n_shared > 0: @@ -474,10 +468,10 @@ class TabNet(nn.Module): else: self.shared = None - self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device) + self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs) self.steps = nn.ModuleList() for x in range(n_steps - 1): - self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device)) + self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs)) self.fc = nn.Linear(n_d, out_dim) self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01) self.n_d = n_d @@ -486,14 +480,14 @@ class TabNet(nn.Module): assert not torch.isnan(x).any() x = self.bn(x) x_a = self.first_step(x)[:, self.n_d :] - sparse_loss = torch.zeros(1).to(x.device) + sparse_loss = [] out = torch.zeros(x.size(0), self.n_d).to(x.device) for step in self.steps: x_te, l = step(x, x_a, priors) out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer x_a = x_te[:, self.n_d :] - sparse_loss += l - return self.fc(out), sparse_loss + sparse_loss.append(l) + return self.fc(out), sum(sparse_loss) class GBN(nn.Module): @@ -506,14 +500,17 @@ class GBN(nn.Module): """ def __init__(self, inp, vbs=1024, momentum=0.01): - super().__init__() + super(GBN, self).__init__() self.bn = nn.BatchNorm1d(inp, momentum=momentum) self.vbs = vbs def forward(self, x): - chunk = torch.chunk(x, x.size(0) // self.vbs, 0) - res = [self.bn(y) for y in chunk] - return torch.cat(res, 0) + if x.size(0) <= self.vbs: # can not be chunked + return self.bn(x) + else: + chunk = torch.chunk(x, x.size(0) // self.vbs, 0) + res = [self.bn(y) for y in chunk] + return torch.cat(res, 0) class GLU(nn.Module): @@ -525,7 +522,7 @@ class GLU(nn.Module): """ def __init__(self, inp_dim, out_dim, fc=None, vbs=1024): - super().__init__() + super(GLU, self).__init__() if fc: self.fc = fc else: @@ -561,8 +558,8 @@ class AttentionTransformer(nn.Module): class FeatureTransformer(nn.Module): - def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device): - super().__init__() + def __init__(self, inp_dim, out_dim, shared, n_ind, vbs): + super(FeatureTransformer, self).__init__() first = True self.shared = nn.ModuleList() if shared: @@ -577,7 +574,7 @@ class FeatureTransformer(nn.Module): self.independ.append(GLU(inp, out_dim, vbs=vbs)) for x in range(first, n_ind): self.independ.append(GLU(out_dim, out_dim, vbs=vbs)) - self.scale = torch.sqrt(torch.tensor([0.5], device=device)) + self.scale = float(np.sqrt(0.5)) def forward(self, x): if self.shared: @@ -596,10 +593,10 @@ class DecisionStep(nn.Module): One step for the TabNet """ - def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device): + def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs): super().__init__() self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs) - self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device) + self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs) def forward(self, x, a, priors): mask = self.atten_tran(a, priors)