Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)

* Test case for #3936

* multigpu tests pass on pytorch 1.4.0

* Fixup

* multigpu tests pass on pytorch 1.5.0

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* rename multigpu to require_multigpu

* more doc
This commit is contained in:
Julien Chaumond
2020-05-18 20:34:50 -04:00
committed by GitHub
parent 9de4afa897
commit 4c06893610
12 changed files with 95 additions and 21 deletions

View File

@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

View File

@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
attention_mask, input_shape, self.device
)
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]

View File

@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
past_key_value_states = [None] * len(self.block) past_key_value_states = [None] * len(self.block)
# ourselves in which case we just need to make it broadcastable to all heads. # ourselves in which case we just need to make it broadcastable to all heads.
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device) extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
if self.is_decoder and encoder_attention_mask is not None: if self.is_decoder and encoder_attention_mask is not None:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

View File

@@ -17,7 +17,7 @@
import inspect import inspect
import logging import logging
import os import os
from typing import Callable, Tuple from typing import Callable, List, Tuple
import torch import torch
from torch import Tensor, device, dtype, nn from torch import Tensor, device, dtype, nn
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
@property @property
def device(self) -> device: def device(self) -> device:
try:
return next(self.parameters()).device return next(self.parameters()).device
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].device
@property @property
def dtype(self) -> dtype: def dtype(self) -> dtype:
try:
return next(self.parameters()).dtype return next(self.parameters()).dtype
except StopIteration:
# For nn.DataParallel compatibility in PyTorch 1.5
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
return tuples
gen = self._named_members(get_members_fn=find_tensor_attributes)
first_tuple = next(gen)
return first_tuple[1].dtype
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
"""type: torch.Tensor -> torch.Tensor""" """type: torch.Tensor -> torch.Tensor"""

View File

@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
mask_lo = torch.tril(attn_mask, diagonal=-1) mask_lo = torch.tril(attn_mask, diagonal=-1)
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1) ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
ret = ret.to(next(self.parameters())) ret = ret.to(self.device)
return ret return ret
def cache_mem(self, curr_out, prev_mem): def cache_mem(self, curr_out, prev_mem):
@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
pos_emb = pos_emb.to(next(self.parameters())) pos_emb = pos_emb.to(self.device)
return pos_emb return pos_emb
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0 mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen klen = mlen + qlen
dtype_float = next(self.parameters()).dtype dtype_float = self.dtype
device = next(self.parameters()).device device = self.device
# Attention mask # Attention mask
# causal attention mask # causal attention mask

View File

@@ -23,7 +23,7 @@ from typing import List
from transformers import is_torch_available from transformers import is_torch_available
from .utils import require_torch, slow, torch_device from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
@@ -758,6 +758,31 @@ class ModelTesterMixin:
return True return True
return False return False
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    """Run a no-grad forward pass of every model class wrapped in nn.DataParallel.

    The test only checks that the wrapped forward call completes; the
    outputs themselves are discarded.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Some params shouldn't be scattered by nn.DataParallel,
    # so drop them from the inputs when they are present.
    for non_batched_name in ("head_mask",):
        inputs_dict.pop(non_batched_name, None)

    # Move every tensor input to cuda:0. The comprehension is fully built
    # before the update, so the dict is never mutated while being iterated.
    inputs_dict.update(
        {name: value.to(0) for name, value in inputs_dict.items() if torch.is_tensor(value)}
    )

    for model_class in self.all_model_classes:
        base_model = model_class(config=config)
        base_model.to(0)
        base_model.eval()

        # Wrap the model in nn.DataParallel and run it on the prepared inputs.
        parallel_model = torch.nn.DataParallel(base_model)
        with torch.no_grad():
            parallel_model(**inputs_dict)
global_rng = random.Random() global_rng = random.Random()

View File

@@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=14,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_token_type_ids=True, use_token_type_ids=True,

View File

@@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=14,
seq_length=7, seq_length=7,
is_training=True, is_training=True,
use_token_type_ids=True, use_token_type_ids=True,

View File

@@ -19,7 +19,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
@@ -448,9 +448,14 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    """Opt out of this test: the body is intentionally empty."""
    return None
@require_torch @require_torch
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase): class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False test_pruning = False
@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
@require_torch @require_torch
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin): class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False test_pruning = False

View File

@@ -21,7 +21,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available(): if is_torch_available():
@@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=14,
seq_length=7, seq_length=7,
mem_len=30, mem_len=30,
clamp_len=15, clamp_len=15,
@@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
self.model_tester.check_transfo_xl_lm_head_output(output_result) self.model_tester.check_transfo_xl_lm_head_output(output_result)
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    """Opt out of this test: the body is intentionally empty."""
    return None
@slow @slow
def test_model_from_pretrained(self): def test_model_from_pretrained(self):
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:

View File

@@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
def __init__( def __init__(
self, self,
parent, parent,
batch_size=13, batch_size=14,
seq_length=7, seq_length=7,
mem_len=10, mem_len=10,
clamp_len=-1, clamp_len=-1,

View File

@@ -94,6 +94,25 @@ def require_tf(test_case):
return test_case return test_case
def require_multigpu(test_case):
    """
    Decorator marking a test that requires a multi-GPU setup (in PyTorch).

    The test is returned unchanged when PyTorch is available and reports at
    least two CUDA devices; otherwise it is wrapped in ``unittest.skip``.

    To run *only* the multigpu tests, assuming all test names contain multigpu:
    $ pytest -sv ./tests -k "multigpu"
    """
    if not _torch_available:
        return unittest.skip("test requires PyTorch")(test_case)

    # Only imported once PyTorch is known to be available.
    import torch

    enough_gpus = torch.cuda.device_count() >= 2
    if enough_gpus:
        return test_case
    return unittest.skip("test requires multiple GPUs")(test_case)
if _torch_available: if _torch_available:
# Set the USE_CUDA environment variable to select a GPU. # Set the USE_CUDA environment variable to select a GPU.
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu" torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"