Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)
* Test case for #3936
* multigpu tests pass on pytorch 1.4.0
* Fixup
* multigpu tests pass on pytorch 1.5.0
* Update src/transformers/modeling_utils.py
* Update src/transformers/modeling_utils.py
* rename multigpu to require_multigpu
* mode doc
This commit is contained in:
@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
|
||||
@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, input_shape, self.device
|
||||
)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
past_key_value_states = [None] * len(self.block)
|
||||
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
|
||||
|
||||
if self.is_decoder and encoder_attention_mask is not None:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
return next(self.parameters()).device
|
||||
try:
|
||||
return next(self.parameters()).device
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].device
|
||||
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
return next(self.parameters()).dtype
|
||||
try:
|
||||
return next(self.parameters()).dtype
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
|
||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||
"""type: torch.Tensor -> torch.Tensor"""
|
||||
|
||||
@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
mask_lo = torch.tril(attn_mask, diagonal=-1)
|
||||
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
|
||||
|
||||
ret = ret.to(next(self.parameters()))
|
||||
ret = ret.to(self.device)
|
||||
return ret
|
||||
|
||||
def cache_mem(self, curr_out, prev_mem):
|
||||
@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||
|
||||
pos_emb = pos_emb.to(next(self.parameters()))
|
||||
pos_emb = pos_emb.to(self.device)
|
||||
return pos_emb
|
||||
|
||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||
@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
dtype_float = next(self.parameters()).dtype
|
||||
device = next(self.parameters()).device
|
||||
dtype_float = self.dtype
|
||||
device = self.device
|
||||
|
||||
# Attention mask
|
||||
# causal attention mask
|
||||
|
||||
Reference in New Issue
Block a user