Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)
* Test case for #3936 * multigpu tests pass on pytorch 1.4.0 * Fixup * multigpu tests pass on pytorch 1.5.0 * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * rename multigpu to require_multigpu * mode doc
This commit is contained in:
@@ -23,7 +23,7 @@ from typing import List
|
||||
|
||||
from transformers import is_torch_available
|
||||
|
||||
from .utils import require_torch, slow, torch_device
|
||||
from .utils import require_multigpu, require_torch, slow, torch_device
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -758,6 +758,31 @@ class ModelTesterMixin:
|
||||
return True
|
||||
return False
|
||||
|
||||
@require_multigpu
def test_multigpu_data_parallel_forward(self):
    """Run a no-grad forward pass of each model class wrapped in nn.DataParallel.

    The common config and inputs come from ``self.model_tester``; every class
    in ``self.all_model_classes`` is instantiated, moved to device 0, put in
    eval mode, wrapped in ``torch.nn.DataParallel`` and called once. The test
    only checks that the forward call completes; outputs are discarded.
    """
    config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

    # Some params shouldn't be scattered by nn.DataParallel,
    # so drop them from the inputs when they are present.
    for param_name in ("head_mask",):
        inputs_dict.pop(param_name, None)

    # Move input tensors to cuda:0 (non-tensor entries are left untouched).
    for key in list(inputs_dict):
        value = inputs_dict[key]
        if torch.is_tensor(value):
            inputs_dict[key] = value.to(0)

    for model_class in self.all_model_classes:
        base_model = model_class(config=config)
        base_model.to(0)
        base_model.eval()

        # Wrap the model in nn.DataParallel and run a single forward pass.
        parallel_model = torch.nn.DataParallel(base_model)
        with torch.no_grad():
            parallel_model(**inputs_dict)
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user