num_parameters helper

This commit is contained in:
Julien Chaumond
2020-01-10 17:40:02 +00:00
parent 331065e62d
commit 84c0aa1868
4 changed files with 35 additions and 2 deletions

View File

@@ -53,7 +53,20 @@ except ImportError:
return input
class PreTrainedModel(nn.Module):
class ModuleUtils:
"""
A few utilities for torch.nn.Modules, to be used as a mixin.
"""
def num_parameters(self, only_trainable: bool = False) -> int:
"""
Get number of (optionally, trainable) parameters in the module.
"""
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
return sum(p.numel() for p in params)
class PreTrainedModel(nn.Module, ModuleUtils):
r""" Base class for all models.
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models