num_parameters helper
This commit is contained in:
@@ -53,7 +53,20 @@ except ImportError:
|
||||
return input
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module):
|
||||
class ModuleUtils:
|
||||
"""
|
||||
A few utilities for torch.nn.Modules, to be used as a mixin.
|
||||
"""
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable) parameters in the module.
|
||||
"""
|
||||
params = filter(lambda x: x.requires_grad, self.parameters()) if only_trainable else self.parameters()
|
||||
return sum(p.numel() for p in params)
|
||||
|
||||
|
||||
class PreTrainedModel(nn.Module, ModuleUtils):
|
||||
r""" Base class for all models.
|
||||
|
||||
:class:`~transformers.PreTrainedModel` takes care of storing the configuration of the models and handles methods for loading/downloading/saving models
|
||||
|
||||
Reference in New Issue
Block a user