mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-04 03:21:00 +08:00
Fix Ghost BN bugs in TabNet and simplify its implementation
This commit is contained in:
@@ -93,12 +93,8 @@ class TabnetModel(Model):
|
||||
np.random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
|
||||
self.tabnet_model = TabNet(
|
||||
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
|
||||
).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
|
||||
self.device
|
||||
)
|
||||
self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device)
|
||||
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
@@ -401,7 +397,7 @@ class FinetuneModel(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, input_dim, output_dim, trained_model):
|
||||
super().__init__()
|
||||
super(FinetuneModel, self).__init__()
|
||||
self.model = trained_model
|
||||
self.fc = nn.Linear(input_dim, output_dim)
|
||||
|
||||
@@ -410,9 +406,9 @@ class FinetuneModel(nn.Module):
|
||||
|
||||
|
||||
class DecoderStep(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super(DecoderStep, self).__init__()
|
||||
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)
|
||||
self.fc = nn.Linear(out_dim, out_dim)
|
||||
|
||||
def forward(self, x):
|
||||
@@ -421,13 +417,13 @@ class DecoderStep(nn.Module):
|
||||
|
||||
|
||||
class TabNet_Decoder(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
|
||||
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):
|
||||
"""
|
||||
TabNet decoder that is used in pre-training
|
||||
"""
|
||||
self.out_dim = out_dim
|
||||
|
||||
super().__init__()
|
||||
super(TabNet_Decoder, self).__init__()
|
||||
if n_shared > 0:
|
||||
self.shared = nn.ModuleList()
|
||||
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
|
||||
@@ -438,7 +434,7 @@ class TabNet_Decoder(nn.Module):
|
||||
self.n_steps = n_steps
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps):
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
|
||||
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))
|
||||
|
||||
def forward(self, x):
|
||||
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
|
||||
@@ -448,9 +444,7 @@ class TabNet_Decoder(nn.Module):
|
||||
|
||||
|
||||
class TabNet(nn.Module):
|
||||
def __init__(
|
||||
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
|
||||
):
|
||||
def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):
|
||||
"""
|
||||
TabNet AKA the original encoder
|
||||
|
||||
@@ -463,7 +457,7 @@ class TabNet(nn.Module):
|
||||
relax coefficient:
|
||||
virtual batch size:
|
||||
"""
|
||||
super().__init__()
|
||||
super(TabNet, self).__init__()
|
||||
|
||||
# set the number of shared step in feature transformer
|
||||
if n_shared > 0:
|
||||
@@ -474,10 +468,10 @@ class TabNet(nn.Module):
|
||||
else:
|
||||
self.shared = None
|
||||
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
|
||||
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
|
||||
self.steps = nn.ModuleList()
|
||||
for x in range(n_steps - 1):
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
|
||||
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
|
||||
self.fc = nn.Linear(n_d, out_dim)
|
||||
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
|
||||
self.n_d = n_d
|
||||
@@ -486,14 +480,14 @@ class TabNet(nn.Module):
|
||||
assert not torch.isnan(x).any()
|
||||
x = self.bn(x)
|
||||
x_a = self.first_step(x)[:, self.n_d :]
|
||||
sparse_loss = torch.zeros(1).to(x.device)
|
||||
sparse_loss = []
|
||||
out = torch.zeros(x.size(0), self.n_d).to(x.device)
|
||||
for step in self.steps:
|
||||
x_te, l = step(x, x_a, priors)
|
||||
out += F.relu(x_te[:, : self.n_d]) # split the feautre from feat_transformer
|
||||
x_a = x_te[:, self.n_d :]
|
||||
sparse_loss += l
|
||||
return self.fc(out), sparse_loss
|
||||
sparse_loss.append(l)
|
||||
return self.fc(out), sum(sparse_loss)
|
||||
|
||||
|
||||
class GBN(nn.Module):
|
||||
@@ -506,14 +500,17 @@ class GBN(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, inp, vbs=1024, momentum=0.01):
|
||||
super().__init__()
|
||||
super(GBN, self).__init__()
|
||||
self.bn = nn.BatchNorm1d(inp, momentum=momentum)
|
||||
self.vbs = vbs
|
||||
|
||||
def forward(self, x):
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
if x.size(0) <= self.vbs: # can not be chunked
|
||||
return self.bn(x)
|
||||
else:
|
||||
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
|
||||
res = [self.bn(y) for y in chunk]
|
||||
return torch.cat(res, 0)
|
||||
|
||||
|
||||
class GLU(nn.Module):
|
||||
@@ -525,7 +522,7 @@ class GLU(nn.Module):
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, out_dim, fc=None, vbs=1024):
|
||||
super().__init__()
|
||||
super(GLU, self).__init__()
|
||||
if fc:
|
||||
self.fc = fc
|
||||
else:
|
||||
@@ -561,8 +558,8 @@ class AttentionTransformer(nn.Module):
|
||||
|
||||
|
||||
class FeatureTransformer(nn.Module):
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
|
||||
super().__init__()
|
||||
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
|
||||
super(FeatureTransformer, self).__init__()
|
||||
first = True
|
||||
self.shared = nn.ModuleList()
|
||||
if shared:
|
||||
@@ -577,7 +574,7 @@ class FeatureTransformer(nn.Module):
|
||||
self.independ.append(GLU(inp, out_dim, vbs=vbs))
|
||||
for x in range(first, n_ind):
|
||||
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
|
||||
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
|
||||
self.scale = float(np.sqrt(0.5))
|
||||
|
||||
def forward(self, x):
|
||||
if self.shared:
|
||||
@@ -596,10 +593,10 @@ class DecisionStep(nn.Module):
|
||||
One step for the TabNet
|
||||
"""
|
||||
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
|
||||
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):
|
||||
super().__init__()
|
||||
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
|
||||
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
|
||||
|
||||
def forward(self, x, a, priors):
|
||||
mask = self.atten_tran(a, priors)
|
||||
|
||||
Reference in New Issue
Block a user