diff --git a/qlib/contrib/model/pytorch_utils.py b/qlib/contrib/model/pytorch_utils.py index c0d483db3..e7a8e8d67 100644 --- a/qlib/contrib/model/pytorch_utils.py +++ b/qlib/contrib/model/pytorch_utils.py @@ -3,6 +3,7 @@ import torch.nn as nn + def count_parameters(models_or_parameters, unit="mb"): if isinstance(models_or_parameters, nn.Module): counts = sum(v.numel() for v in models_or_parameters.parameters())