1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-02 18:40:58 +08:00

Simplify count_parameters

This commit is contained in:
D-X-Y
2021-03-08 08:16:17 +00:00
parent 7bed3b4c2e
commit ca48345b29
2 changed files with 9 additions and 8 deletions

View File

@@ -100,9 +100,7 @@ class TabnetModel(Model):
self.device
)
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
self.logger.info(
"model size: {:.4f} MB".format(count_parameters(self.tabnet_model) + count_parameters(self.tabnet_decoder))
)
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
if optimizer.lower() == "adam":
self.pretrain_optimizer = optim.Adam(

View File

@@ -1,15 +1,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import numpy as np
import torch.nn as nn
def count_parameters(model_or_parameters, unit="mb"):
if isinstance(model_or_parameters, nn.Module):
counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters())
def count_parameters(models_or_parameters, unit="mb"):
if isinstance(models_or_parameters, nn.Module):
counts = sum(v.numel() for v in models_or_parameters.parameters())
elif isinstance(models_or_parameters, nn.Parameter):
counts = models_or_parameters.numel()
elif isinstance(models_or_parameters, (list, tuple)):
return sum(count_parameters(x, unit) for x in model_or_parameters)
else:
counts = np.sum(np.prod(v.size()) for v in model_or_parameters)
counts = sum(v.numel() for v in models_or_parameters)
if unit.lower() == "mb":
counts /= 1e6
elif unit.lower() == "kb":