1
0
mirror of https://github.com/microsoft/qlib.git synced 2026-07-04 19:41:00 +08:00
Files
qlib/qlib/contrib/model/pytorch_utils.py
2022-02-06 22:33:16 +08:00

38 lines
1.2 KiB
Python

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import torch
import torch.nn as nn
def count_parameters(models_or_parameters, unit="m"):
    """
    Count the parameters of one or more PyTorch models, optionally scaled by a unit.

    Parameters
    ----------
    models_or_parameters : nn.Module, torch.Tensor (incl. nn.Parameter), an iterable of
        tensors (e.g. ``model.parameters()``), or a list/tuple mixing any of these.
    unit : str or None
        Case-insensitive scale applied to the raw count: "k"/"kb" divides by 2**10,
        "m"/"mb" by 2**20, "g"/"gb" by 2**30. ``None`` returns the raw count.
        NOTE: these are binary multiples of the parameter *count*, not bytes of storage.

    Returns
    -------
    The number of parameters: an int when ``unit`` is None, otherwise a float.

    Raises
    ------
    ValueError
        If ``unit`` is neither None nor one of the recognized units.
    """
    scales = {"k": 2**10, "kb": 2**10, "m": 2**20, "mb": 2**20, "g": 2**30, "gb": 2**30}
    # Normalize/validate the unit before counting so that a bad unit is reported even for
    # empty inputs, and so that ``unit=None`` no longer crashes on ``.lower()``.
    if unit is not None:
        unit = unit.lower()
        if unit not in scales:
            raise ValueError("Unknown unit: {:}".format(unit))

    if isinstance(models_or_parameters, nn.Module):
        counts = sum(v.numel() for v in models_or_parameters.parameters())
    elif isinstance(models_or_parameters, torch.Tensor):
        # nn.Parameter is a torch.Tensor subclass; numel() also covers plain and 0-d tensors.
        counts = models_or_parameters.numel()
    elif isinstance(models_or_parameters, (list, tuple)):
        # Sum raw counts and scale once at the end instead of scaling every element.
        counts = sum(count_parameters(x, None) for x in models_or_parameters)
    else:
        # Any other iterable of tensors, e.g. the generator from ``model.parameters()``.
        counts = sum(v.numel() for v in models_or_parameters)

    if unit is None:
        return counts
    return counts / scales[unit]