mirror of
https://github.com/microsoft/qlib.git
synced 2026-07-02 18:40:58 +08:00
Simplify count_parameters
This commit is contained in:
@@ -100,9 +100,7 @@ class TabnetModel(Model):
|
||||
self.device
|
||||
)
|
||||
self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder))
|
||||
self.logger.info(
|
||||
"model size: {:.4f} MB".format(count_parameters(self.tabnet_model) + count_parameters(self.tabnet_decoder))
|
||||
)
|
||||
self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder])))
|
||||
|
||||
if optimizer.lower() == "adam":
|
||||
self.pretrain_optimizer = optim.Adam(
|
||||
|
||||
@@ -1,15 +1,18 @@
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def count_parameters(model_or_parameters, unit="mb"):
|
||||
if isinstance(model_or_parameters, nn.Module):
|
||||
counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters())
|
||||
def count_parameters(models_or_parameters, unit="mb"):
|
||||
if isinstance(models_or_parameters, nn.Module):
|
||||
counts = sum(v.numel() for v in models_or_parameters.parameters())
|
||||
elif isinstance(models_or_parameters, nn.Parameter):
|
||||
counts = models_or_parameters.numel()
|
||||
elif isinstance(models_or_parameters, (list, tuple)):
|
||||
return sum(count_parameters(x, unit) for x in model_or_parameters)
|
||||
else:
|
||||
counts = np.sum(np.prod(v.size()) for v in model_or_parameters)
|
||||
counts = sum(v.numel() for v in models_or_parameters)
|
||||
if unit.lower() == "mb":
|
||||
counts /= 1e6
|
||||
elif unit.lower() == "kb":
|
||||
|
||||
Reference in New Issue
Block a user