1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 03:21:00 +08:00

Fix Ghost BN bugs in TabNet and simplify its implementation

This commit is contained in:
D-X-Y
2021-03-14 07:25:09 +00:00
parent 1d2b2f4f01
commit d5f9395e51

View File

@@ -93,12 +93,8 @@ class TabnetModel(Model):
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.tabnet_model = TabNet(
inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax, device=self.device
).to(self.device)
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps, self.device).to(
self.device
)
self.tabnet_model = TabNet(inp_dim=self.d_feat, out_dim=self.out_dim, vbs=vbs, relax=relax).to(self.device)
self.tabnet_decoder = TabNet_Decoder(self.out_dim, self.d_feat, n_shared, n_ind, vbs, n_steps).to(self.device)
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
@@ -401,7 +397,7 @@ class FinetuneModel(nn.Module):
"""
def __init__(self, input_dim, output_dim, trained_model):
super().__init__()
super(FinetuneModel, self).__init__()
self.model = trained_model
self.fc = nn.Linear(input_dim, output_dim)
@@ -410,9 +406,9 @@ class FinetuneModel(nn.Module):
class DecoderStep(nn.Module):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
super().__init__()
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs, device)
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
super(DecoderStep, self).__init__()
self.fea_tran = FeatureTransformer(inp_dim, out_dim, shared, n_ind, vbs)
self.fc = nn.Linear(out_dim, out_dim)
def forward(self, x):
@@ -421,13 +417,13 @@ class DecoderStep(nn.Module):
class TabNet_Decoder(nn.Module):
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps, device):
def __init__(self, inp_dim, out_dim, n_shared, n_ind, vbs, n_steps):
"""
TabNet decoder that is used in pre-training
"""
self.out_dim = out_dim
super().__init__()
super(TabNet_Decoder, self).__init__()
if n_shared > 0:
self.shared = nn.ModuleList()
self.shared.append(nn.Linear(inp_dim, 2 * out_dim))
@@ -438,7 +434,7 @@ class TabNet_Decoder(nn.Module):
self.n_steps = n_steps
self.steps = nn.ModuleList()
for x in range(n_steps):
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs, device))
self.steps.append(DecoderStep(inp_dim, out_dim, self.shared, n_ind, vbs))
def forward(self, x):
out = torch.zeros(x.size(0), self.out_dim).to(x.device)
@@ -448,9 +444,7 @@ class TabNet_Decoder(nn.Module):
class TabNet(nn.Module):
def __init__(
self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024, device="cpu"
):
def __init__(self, inp_dim=6, out_dim=6, n_d=64, n_a=64, n_shared=2, n_ind=2, n_steps=5, relax=1.2, vbs=1024):
"""
TabNet AKA the original encoder
@@ -463,7 +457,7 @@ class TabNet(nn.Module):
relax coefficient:
virtual batch size:
"""
super().__init__()
super(TabNet, self).__init__()
# set the number of shared step in feature transformer
if n_shared > 0:
@@ -474,10 +468,10 @@ class TabNet(nn.Module):
else:
self.shared = None
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs, device)
self.first_step = FeatureTransformer(inp_dim, n_d + n_a, self.shared, n_ind, vbs)
self.steps = nn.ModuleList()
for x in range(n_steps - 1):
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs, device))
self.steps.append(DecisionStep(inp_dim, n_d, n_a, self.shared, n_ind, relax, vbs))
self.fc = nn.Linear(n_d, out_dim)
self.bn = nn.BatchNorm1d(inp_dim, momentum=0.01)
self.n_d = n_d
@@ -486,14 +480,14 @@ class TabNet(nn.Module):
assert not torch.isnan(x).any()
x = self.bn(x)
x_a = self.first_step(x)[:, self.n_d :]
sparse_loss = torch.zeros(1).to(x.device)
sparse_loss = []
out = torch.zeros(x.size(0), self.n_d).to(x.device)
for step in self.steps:
x_te, l = step(x, x_a, priors)
out += F.relu(x_te[:, : self.n_d]) # split the feature from feat_transformer
x_a = x_te[:, self.n_d :]
sparse_loss += l
return self.fc(out), sparse_loss
sparse_loss.append(l)
return self.fc(out), sum(sparse_loss)
class GBN(nn.Module):
@@ -506,14 +500,17 @@ class GBN(nn.Module):
"""
def __init__(self, inp, vbs=1024, momentum=0.01):
super().__init__()
super(GBN, self).__init__()
self.bn = nn.BatchNorm1d(inp, momentum=momentum)
self.vbs = vbs
def forward(self, x):
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
res = [self.bn(y) for y in chunk]
return torch.cat(res, 0)
if x.size(0) <= self.vbs: # can not be chunked
return self.bn(x)
else:
chunk = torch.chunk(x, x.size(0) // self.vbs, 0)
res = [self.bn(y) for y in chunk]
return torch.cat(res, 0)
class GLU(nn.Module):
@@ -525,7 +522,7 @@ class GLU(nn.Module):
"""
def __init__(self, inp_dim, out_dim, fc=None, vbs=1024):
super().__init__()
super(GLU, self).__init__()
if fc:
self.fc = fc
else:
@@ -561,8 +558,8 @@ class AttentionTransformer(nn.Module):
class FeatureTransformer(nn.Module):
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs, device):
super().__init__()
def __init__(self, inp_dim, out_dim, shared, n_ind, vbs):
super(FeatureTransformer, self).__init__()
first = True
self.shared = nn.ModuleList()
if shared:
@@ -577,7 +574,7 @@ class FeatureTransformer(nn.Module):
self.independ.append(GLU(inp, out_dim, vbs=vbs))
for x in range(first, n_ind):
self.independ.append(GLU(out_dim, out_dim, vbs=vbs))
self.scale = torch.sqrt(torch.tensor([0.5], device=device))
self.scale = float(np.sqrt(0.5))
def forward(self, x):
if self.shared:
@@ -596,10 +593,10 @@ class DecisionStep(nn.Module):
One step for the TabNet
"""
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs, device):
def __init__(self, inp_dim, n_d, n_a, shared, n_ind, relax, vbs):
super().__init__()
self.atten_tran = AttentionTransformer(n_a, inp_dim, relax, vbs)
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs, device)
self.fea_tran = FeatureTransformer(inp_dim, n_d + n_a, shared, n_ind, vbs)
def forward(self, x, a, priors):
mask = self.atten_tran(a, priors)