From ca48345b29dd40d6b50c87bd027eca51cb522a05 Mon Sep 17 00:00:00 2001 From: D-X-Y <280835372@qq.com> Date: Mon, 8 Mar 2021 08:16:17 +0000 Subject: [PATCH] Simplify count_parameters --- qlib/contrib/model/pytorch_tabnet.py | 4 +--- qlib/contrib/model/pytorch_utils.py | 13 ++++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/qlib/contrib/model/pytorch_tabnet.py b/qlib/contrib/model/pytorch_tabnet.py index 682e1b19f..020bbaff2 100644 --- a/qlib/contrib/model/pytorch_tabnet.py +++ b/qlib/contrib/model/pytorch_tabnet.py @@ -100,9 +100,7 @@ class TabnetModel(Model): self.device ) self.logger.info("model:\n{:}\n{:}".format(self.tabnet_model, self.tabnet_decoder)) - self.logger.info( - "model size: {:.4f} MB".format(count_parameters(self.tabnet_model) + count_parameters(self.tabnet_decoder)) - ) + self.logger.info("model size: {:.4f} MB".format(count_parameters([self.tabnet_model, self.tabnet_decoder]))) if optimizer.lower() == "adam": self.pretrain_optimizer = optim.Adam( diff --git a/qlib/contrib/model/pytorch_utils.py b/qlib/contrib/model/pytorch_utils.py index 532969eb5..f2457c9a6 100644 --- a/qlib/contrib/model/pytorch_utils.py +++ b/qlib/contrib/model/pytorch_utils.py @@ -1,15 +1,18 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import numpy as np import torch.nn as nn -def count_parameters(model_or_parameters, unit="mb"): - if isinstance(model_or_parameters, nn.Module): - counts = np.sum(np.prod(v.size()) for v in model_or_parameters.parameters()) +def count_parameters(models_or_parameters, unit="mb"): + if isinstance(models_or_parameters, nn.Module): + counts = sum(v.numel() for v in models_or_parameters.parameters()) + elif isinstance(models_or_parameters, nn.Parameter): + counts = models_or_parameters.numel() + elif isinstance(models_or_parameters, (list, tuple)): + return sum(count_parameters(x, unit) for x in model_or_parameters) else: - counts = np.sum(np.prod(v.size()) for v in model_or_parameters) + counts = sum(v.numel() for v in models_or_parameters) if unit.lower() == "mb": counts /= 1e6 elif unit.lower() == "kb":