Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)
* Test case for #3936 * multigpu tests pass on pytorch 1.4.0 * Fixup * multigpu tests pass on pytorch 1.5.0 * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * rename multigpu to require_multigpu * mode doc
This commit is contained in:
@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
|
||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||
|
||||
|
||||
@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
||||
attention_mask, input_shape, self.device
|
||||
)
|
||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||
|
||||
@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
|
||||
past_key_value_states = [None] * len(self.block)
|
||||
|
||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
|
||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
|
||||
|
||||
if self.is_decoder and encoder_attention_mask is not None:
|
||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
from typing import Callable, Tuple
|
||||
from typing import Callable, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device, dtype, nn
|
||||
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
try:
|
||||
return next(self.parameters()).device
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].device
|
||||
|
||||
@property
|
||||
def dtype(self) -> dtype:
|
||||
try:
|
||||
return next(self.parameters()).dtype
|
||||
except StopIteration:
|
||||
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
|
||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||
"""type: torch.Tensor -> torch.Tensor"""
|
||||
|
||||
@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
mask_lo = torch.tril(attn_mask, diagonal=-1)
|
||||
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
|
||||
|
||||
ret = ret.to(next(self.parameters()))
|
||||
ret = ret.to(self.device)
|
||||
return ret
|
||||
|
||||
def cache_mem(self, curr_out, prev_mem):
|
||||
@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||
|
||||
pos_emb = pos_emb.to(next(self.parameters()))
|
||||
pos_emb = pos_emb.to(self.device)
|
||||
return pos_emb
|
||||
|
||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||
@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
|
||||
klen = mlen + qlen
|
||||
|
||||
dtype_float = next(self.parameters()).dtype
|
||||
device = next(self.parameters()).device
|
||||
dtype_float = self.dtype
|
||||
device = self.device
|
||||
|
||||
# Attention mask
|
||||
# causal attention mask
|
||||
|
||||
@@ -23,7 +23,7 @@ from typing import List
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -758,6 +758,31 @@ class ModelTesterMixin:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
# some params shouldn't be scattered by nn.DataParallel
|
||||
# so just remove them if they are present.
|
||||
blacklist_non_batched_params = ["head_mask"]
|
||||
for k in blacklist_non_batched_params:
|
||||
inputs_dict.pop(k, None)
|
||||
|
||||
# move input tensors to cuda:O
|
||||
for k, v in inputs_dict.items():
|
||||
if torch.is_tensor(v):
|
||||
inputs_dict[k] = v.to(0)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
model.to(0)
|
||||
model.eval()
|
||||
|
||||
# Wrap model in nn.DataParallel
|
||||
model = torch.nn.DataParallel(model)
|
||||
with torch.no_grad():
|
||||
_ = model(**inputs_dict)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_token_type_ids=True,
|
||||
|
||||
@@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
is_training=True,
|
||||
use_token_type_ids=True,
|
||||
|
||||
@@ -19,7 +19,7 @@ from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -448,9 +448,14 @@ class ReformerTesterMixin:
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
||||
|
||||
@require_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
# Opt-out of this test.
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
|
||||
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
|
||||
|
||||
|
||||
@require_torch
|
||||
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
|
||||
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||
test_pruning = False
|
||||
|
||||
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
mem_len=30,
|
||||
clamp_len=15,
|
||||
@@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||
|
||||
@require_multigpu
|
||||
def test_multigpu_data_parallel_forward(self):
|
||||
# Opt-out of this test.
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
|
||||
@@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def __init__(
|
||||
self,
|
||||
parent,
|
||||
batch_size=13,
|
||||
batch_size=14,
|
||||
seq_length=7,
|
||||
mem_len=10,
|
||||
clamp_len=-1,
|
||||
|
||||
@@ -94,6 +94,25 @@ def require_tf(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_multigpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
|
||||
|
||||
These tests are skipped on a machine without multiple GPUs.
|
||||
|
||||
To run *only* the multigpu tests, assuming all test names contain multigpu:
|
||||
$ pytest -sv ./tests -k "multigpu"
|
||||
"""
|
||||
if not _torch_available:
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if torch.cuda.device_count() < 2:
|
||||
return unittest.skip("test requires multiple GPUs")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
if _torch_available:
|
||||
# Set the USE_CUDA environment variable to select a GPU.
|
||||
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
||||
|
||||
Reference in New Issue
Block a user