Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)
* Test case for #3936 * multigpu tests pass on pytorch 1.4.0 * Fixup * multigpu tests pass on pytorch 1.5.0 * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * rename multigpu to require_multigpu * mode doc
This commit is contained in:
@@ -550,7 +550,7 @@ class AlbertModel(AlbertPreTrainedModel):
|
|||||||
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
|
||||||
|
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||||
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
||||||
|
|
||||||
|
|||||||
@@ -703,9 +703,7 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(
|
extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)
|
||||||
attention_mask, input_shape, self.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
|||||||
@@ -704,7 +704,7 @@ class T5Stack(T5PreTrainedModel):
|
|||||||
past_key_value_states = [None] * len(self.block)
|
past_key_value_states = [None] * len(self.block)
|
||||||
|
|
||||||
# ourselves in which case we just need to make it broadcastable to all heads.
|
# ourselves in which case we just need to make it broadcastable to all heads.
|
||||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, self.device)
|
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, inputs_embeds.device)
|
||||||
|
|
||||||
if self.is_decoder and encoder_attention_mask is not None:
|
if self.is_decoder and encoder_attention_mask is not None:
|
||||||
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Callable, Tuple
|
from typing import Callable, List, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor, device, dtype, nn
|
from torch import Tensor, device, dtype, nn
|
||||||
@@ -110,11 +110,33 @@ class ModuleUtilsMixin:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> device:
|
def device(self) -> device:
|
||||||
|
try:
|
||||||
return next(self.parameters()).device
|
return next(self.parameters()).device
|
||||||
|
except StopIteration:
|
||||||
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
|
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||||
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||||
|
first_tuple = next(gen)
|
||||||
|
return first_tuple[1].device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def dtype(self) -> dtype:
|
def dtype(self) -> dtype:
|
||||||
|
try:
|
||||||
return next(self.parameters()).dtype
|
return next(self.parameters()).dtype
|
||||||
|
except StopIteration:
|
||||||
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
||||||
|
|
||||||
|
def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
|
||||||
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||||
|
return tuples
|
||||||
|
|
||||||
|
gen = self._named_members(get_members_fn=find_tensor_attributes)
|
||||||
|
first_tuple = next(gen)
|
||||||
|
return first_tuple[1].dtype
|
||||||
|
|
||||||
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
|
||||||
"""type: torch.Tensor -> torch.Tensor"""
|
"""type: torch.Tensor -> torch.Tensor"""
|
||||||
|
|||||||
@@ -623,7 +623,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
mask_lo = torch.tril(attn_mask, diagonal=-1)
|
mask_lo = torch.tril(attn_mask, diagonal=-1)
|
||||||
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
|
ret = torch.cat([ret[:, :qlen] + mask_lo, ret[:, qlen:]], dim=1)
|
||||||
|
|
||||||
ret = ret.to(next(self.parameters()))
|
ret = ret.to(self.device)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def cache_mem(self, curr_out, prev_mem):
|
def cache_mem(self, curr_out, prev_mem):
|
||||||
@@ -685,7 +685,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
fwd_pos_seq = fwd_pos_seq.clamp(-self.clamp_len, self.clamp_len)
|
||||||
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
|
||||||
|
|
||||||
pos_emb = pos_emb.to(next(self.parameters()))
|
pos_emb = pos_emb.to(self.device)
|
||||||
return pos_emb
|
return pos_emb
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_callable(XLNET_INPUTS_DOCSTRING)
|
||||||
@@ -761,8 +761,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
|||||||
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
|
mlen = mems[0].shape[0] if mems is not None and mems[0] is not None else 0
|
||||||
klen = mlen + qlen
|
klen = mlen + qlen
|
||||||
|
|
||||||
dtype_float = next(self.parameters()).dtype
|
dtype_float = self.dtype
|
||||||
device = next(self.parameters()).device
|
device = self.device
|
||||||
|
|
||||||
# Attention mask
|
# Attention mask
|
||||||
# causal attention mask
|
# causal attention mask
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ from typing import List
|
|||||||
|
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
|
|
||||||
from .utils import require_torch, slow, torch_device
|
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -758,6 +758,31 @@ class ModelTesterMixin:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@require_multigpu
|
||||||
|
def test_multigpu_data_parallel_forward(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# some params shouldn't be scattered by nn.DataParallel
|
||||||
|
# so just remove them if they are present.
|
||||||
|
blacklist_non_batched_params = ["head_mask"]
|
||||||
|
for k in blacklist_non_batched_params:
|
||||||
|
inputs_dict.pop(k, None)
|
||||||
|
|
||||||
|
# move input tensors to cuda:O
|
||||||
|
for k, v in inputs_dict.items():
|
||||||
|
if torch.is_tensor(v):
|
||||||
|
inputs_dict[k] = v.to(0)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
model = model_class(config=config)
|
||||||
|
model.to(0)
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Wrap model in nn.DataParallel
|
||||||
|
model = torch.nn.DataParallel(model)
|
||||||
|
with torch.no_grad():
|
||||||
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
|
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class CTRLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=14,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
|
|||||||
@@ -46,7 +46,7 @@ class GPT2ModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=14,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
is_training=True,
|
is_training=True,
|
||||||
use_token_type_ids=True,
|
use_token_type_ids=True,
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from transformers import is_torch_available
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
|
||||||
from .utils import require_torch, slow, torch_device
|
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -448,9 +448,14 @@ class ReformerTesterMixin:
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
|
||||||
|
|
||||||
|
@require_multigpu
|
||||||
|
def test_multigpu_data_parallel_forward(self):
|
||||||
|
# Opt-out of this test.
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
|
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
|
|||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
|
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
|
||||||
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ from transformers import is_torch_available
|
|||||||
|
|
||||||
from .test_configuration_common import ConfigTester
|
from .test_configuration_common import ConfigTester
|
||||||
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
from .test_modeling_common import ModelTesterMixin, ids_tensor
|
||||||
from .utils import require_torch, slow, torch_device
|
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -43,7 +43,7 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=14,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
mem_len=30,
|
mem_len=30,
|
||||||
clamp_len=15,
|
clamp_len=15,
|
||||||
@@ -207,6 +207,11 @@ class TransfoXLModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
output_result = self.model_tester.create_transfo_xl_lm_head(*config_and_inputs)
|
||||||
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
self.model_tester.check_transfo_xl_lm_head_output(output_result)
|
||||||
|
|
||||||
|
@require_multigpu
|
||||||
|
def test_multigpu_data_parallel_forward(self):
|
||||||
|
# Opt-out of this test.
|
||||||
|
pass
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_model_from_pretrained(self):
|
def test_model_from_pretrained(self):
|
||||||
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
for model_name in list(TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ class XLNetModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
parent,
|
parent,
|
||||||
batch_size=13,
|
batch_size=14,
|
||||||
seq_length=7,
|
seq_length=7,
|
||||||
mem_len=10,
|
mem_len=10,
|
||||||
clamp_len=-1,
|
clamp_len=-1,
|
||||||
|
|||||||
@@ -94,6 +94,25 @@ def require_tf(test_case):
|
|||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_multigpu(test_case):
|
||||||
|
"""
|
||||||
|
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
|
||||||
|
|
||||||
|
These tests are skipped on a machine without multiple GPUs.
|
||||||
|
|
||||||
|
To run *only* the multigpu tests, assuming all test names contain multigpu:
|
||||||
|
$ pytest -sv ./tests -k "multigpu"
|
||||||
|
"""
|
||||||
|
if not _torch_available:
|
||||||
|
return unittest.skip("test requires PyTorch")(test_case)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if torch.cuda.device_count() < 2:
|
||||||
|
return unittest.skip("test requires multiple GPUs")(test_case)
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
if _torch_available:
|
if _torch_available:
|
||||||
# Set the USE_CUDA environment variable to select a GPU.
|
# Set the USE_CUDA environment variable to select a GPU.
|
||||||
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
||||||
|
|||||||
Reference in New Issue
Block a user