From 4c06893610ee148f6645b7fee21f382de5d53023 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Mon, 18 May 2020 20:34:50 -0400 Subject: [PATCH] Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300) * Test case for #3936 * multigpu tests pass on pytorch 1.4.0 * Fixup * multigpu tests pass on pytorch 1.5.0 * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * rename multigpu to require_multigpu * mode doc --- src/transformers/modeling_albert.py | 2 +- src/transformers/modeling_bert.py | 4 +--- src/transformers/modeling_t5.py | 2 +- src/transformers/modeling_utils.py | 28 +++++++++++++++++++++++++--- src/transformers/modeling_xlnet.py | 8 ++++---- tests/test_modeling_common.py | 27 ++++++++++++++++++++++++++- tests/test_modeling_ctrl.py | 2 +- tests/test_modeling_gpt2.py | 2 +- tests/test_modeling_reformer.py | 11 ++++++++--- tests/test_modeling_transfo_xl.py | 9 +++++++-- tests/test_modeling_xlnet.py | 2 +- tests/utils.py | 19 +++++++++++++++++++ 12 files changed, 95 insertions(+), 21 deletions(-) diff --git a/src/transformers/modeling_albert.py b/src/transformers/modeling_albert.py index 1dd1bcf553..eb50c29ff6 100644 --- a/src/transformers/modeling_albert.py +++ b/src/transformers/modeling_albert.py @@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel): token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device) extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) - extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) diff --git a/src/transformers/modeling_bert.py b/src/transformers/modeling_bert.py index 3e409cfb74..1e31b5c402 100644 --- a/src/transformers/modeling_bert.py +++ b/src/transformers/modeling_bert.py @@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel): # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask: torch.Tensor = self.get_extended_attention_mask( - attention_mask, input_shape, self.device - ) + extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device) # If a 2D ou 3D attention mask is provided for the cross-attention # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length] diff --git a/src/transformers/modeling_t5.py b/src/transformers/modeling_t5.py index 6fcc9453d2..b363316337 100644 --- a/src/transformers/modeling_t5.py +++ b/src/transformers/modeling_t5.py @@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel): past_key_value_states = [None] * len(self.block) # ourselves in which case we just need to make it broadcastable to all heads. - extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device) + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device) if self.is_decoder and encoder_attention_mask is not None: encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 80b51b2912..8fa1c0d3d8 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,7 +17,7 @@ import inspect import logging import os -from typing import Callable, Tuple +from typing import Callable, List, Tuple import torch from torch import Tensor, device, dtype, nn @@ -110,11 +110,33 @@ class ModuleUtilsMixin: @property def device(self) -> device: - return next(self.parameters()).device + try: + return next(self.parameters()).device + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device @property def dtype(self) -> dtype: - return next(self.parameters()).dtype + try: + return next(self.parameters()).dtype + except StopIteration: + # For nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = self._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor: """type: torch.Tensor -> torch.Tensor""" diff --git a/src/transformers/modeling_xlnet.py b/src/transformers/modeling_xlnet.py index 86905ac1bc..1d0f873de4 100644 --- a/src/transformers/modeling_xlnet.py +++ b/src/transformers/modeling_xlnet.py @@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel): mask_lo = torch.tril(attn_mask, diagonal=-1) ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1) - ret = ret.to(next(self.parameters())) + ret = ret.to(self.device) return ret def cache_mem(self, curr_out, prev_mem): @@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel): fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len) pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz) - pos_emb = pos_emb.to(next(self.parameters())) + pos_emb = pos_emb.to(self.device) return pos_emb @add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING) @@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel): mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0 klen = mlen + qlen - dtype_float = next(self.parameters()).dtype - device = next(self.parameters()).device + dtype_float = self.dtype + device = self.device # Attention mask # causal attention mask diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8c9c6a9f5a..009350564a 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -23,7 +23,7 @@ from typing import List from transformers import is_torch_available -from .utils import require_torch, slow, torch_device +from .utils import require_multigpu, require_torch, slow, torch_device if is_torch_available(): @@ -758,6 +758,31 @@ class ModelTesterMixin: return True return False + @require_multigpu + def test_multigpu_data_parallel_forward(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + # some params shouldn't be scattered by nn.DataParallel + # so just remove them if they are present. + blacklist_non_batched_params = ["head_mask"] + for k in blacklist_non_batched_params: + inputs_dict.pop(k, None) + + # move input tensors to cuda:O + for k, v in inputs_dict.items(): + if torch.is_tensor(v): + inputs_dict[k] = v.to(0) + + for model_class in self.all_model_classes: + model = model_class(config=config) + model.to(0) + model.eval() + + # Wrap model in nn.DataParallel + model = torch.nn.DataParallel(model) + with torch.no_grad(): + _ = model(**inputs_dict) + global_rng = random.Random() diff --git a/tests/test_modeling_ctrl.py b/tests/test_modeling_ctrl.py index e6f39c1d7c..098abc90c5 100644 --- a/tests/test_modeling_ctrl.py +++ b/tests/test_modeling_ctrl.py @@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase): def __init__( self, parent, - batch_size=13, + batch_size=14, seq_length=7, is_training=True, use_token_type_ids=True, diff --git a/tests/test_modeling_gpt2.py b/tests/test_modeling_gpt2.py index ec9940cb8f..3d8b890f1c 100644 --- a/tests/test_modeling_gpt2.py +++ b/tests/test_modeling_gpt2.py @@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase): def __init__( self, parent, - batch_size=13, + batch_size=14, seq_length=7, is_training=True, use_token_type_ids=True, diff --git a/tests/test_modeling_reformer.py b/tests/test_modeling_reformer.py index c79b212a8c..0ffc21abcf 100644 --- a/tests/test_modeling_reformer.py +++ b/tests/test_modeling_reformer.py @@ -19,7 +19,7 @@ from transformers import is_torch_available from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor -from .utils import require_torch, slow, torch_device +from .utils import require_multigpu, require_torch, slow, torch_device if is_torch_available(): @@ -448,9 +448,14 @@ class ReformerTesterMixin: config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs) + @require_multigpu + def test_multigpu_data_parallel_forward(self): + # Opt-out of this test. + pass + @require_torch -class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase): +class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase): all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () test_pruning = False @@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest @require_torch -class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin): +class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase): all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else () all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else () test_pruning = False diff --git a/tests/test_modeling_transfo_xl.py b/tests/test_modeling_transfo_xl.py index 494c84d513..c9317b7b78 100644 --- a/tests/test_modeling_transfo_xl.py +++ b/tests/test_modeling_transfo_xl.py @@ -21,7 +21,7 @@ from transformers import is_torch_available from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor -from .utils import require_torch, slow, torch_device +from .utils import require_multigpu, require_torch, slow, torch_device if is_torch_available(): @@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): def __init__( self, parent, - batch_size=13, + batch_size=14, seq_length=7, mem_len=30, clamp_len=15, @@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase): output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs) self.model_tester.check_transfo_xl_lm_head_output(output_result) + @require_multigpu + def test_multigpu_data_parallel_forward(self): + # Opt-out of this test. + pass + @slow def test_model_from_pretrained(self): for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: diff --git a/tests/test_modeling_xlnet.py b/tests/test_modeling_xlnet.py index e2a3037053..e076c2c266 100644 --- a/tests/test_modeling_xlnet.py +++ b/tests/test_modeling_xlnet.py @@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase): def __init__( self, parent, - batch_size=13, + batch_size=14, seq_length=7, mem_len=10, clamp_len=-1, diff --git a/tests/utils.py b/tests/utils.py index b932e2154a..556e3fbff9 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -94,6 +94,25 @@ def require_tf(test_case): return test_case +def require_multigpu(test_case): + """ + Decorator marking a test that requires a multi-GPU setup (in PyTorch). + + These tests are skipped on a machine without multiple GPUs. + + To run *only* the multigpu tests, assuming all test names contain multigpu: + $ pytest -sv ./tests -k "multigpu" + """ + if not _torch_available: + return unittest.skip("test requires PyTorch")(test_case) + + import torch + + if torch.cuda.device_count() < 2: + return unittest.skip("test requires multiple GPUs")(test_case) + return test_case + + if _torch_available: # Set the USE_CUDA environment variable to select a GPU. torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"