From 4b919657313103f1ee903e32a9213b48e6433afe Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Wed, 17 Feb 2021 15:53:43 +0100 Subject: [PATCH] Factor out methods (#10215) --- src/transformers/modeling_utils.py | 60 +++++++++++++++++------------- 1 file changed, 34 insertions(+), 26 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index d0fc1ad0f4..16a5f0452d 100755 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -86,6 +86,36 @@ def find_pruneable_heads_and_indices( return heads, index +def get_parameter_device(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): + try: + return next(parameter.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: Union[nn.Module, GenerationMixin, "ModuleUtilsMixin"]): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + class ModuleUtilsMixin: """ A few utilities for :obj:`torch.nn.Modules`, to be used as a mixin. @@ -145,36 +175,14 @@ class ModuleUtilsMixin: :obj:`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). """ - try: - return next(self.parameters()).device - except StopIteration: - # For nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = self._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].device + return get_parameter_device(self) @property def dtype(self) -> dtype: """ :obj:`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). """ - try: - return next(self.parameters()).dtype - except StopIteration: - # For nn.DataParallel compatibility in PyTorch 1.5 - - def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: - tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] - return tuples - - gen = self._named_members(get_members_fn=find_tensor_attributes) - first_tuple = next(gen) - return first_tuple[1].dtype + return get_parameter_dtype(self) def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: """ @@ -1238,7 +1246,7 @@ class PoolerStartLogits(nn.Module): x = self.dense(hidden_states).squeeze(-1) if p_mask is not None: - if next(self.parameters()).dtype == torch.float16: + if get_parameter_dtype(self) == torch.float16: x = x * (1 - p_mask) - 65500 * p_mask else: x = x * (1 - p_mask) - 1e30 * p_mask @@ -1305,7 +1313,7 @@ class PoolerEndLogits(nn.Module): x = self.dense_1(x).squeeze(-1) if p_mask is not None: - if next(self.parameters()).dtype == torch.float16: + if get_parameter_dtype(self) == torch.float16: x = x * (1 - p_mask) - 65500 * p_mask else: x = x * (1 - p_mask) - 1e30 * p_mask