Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)

* Test case for #3936

* multigpu tests pass on pytorch 1.4.0

* Fixup

* multigpu tests pass on pytorch 1.5.0

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* rename multigpu to require_multigpu

* mode doc
This commit is contained in:
Julien Chaumond
2020-05-18 20:34:50 -04:00
committed by GitHub
parent 9de4afa897
commit 4c06893610
12 changed files with 95 additions and 21 deletions

View File

@@ -19,7 +19,7 @@ from transformers import is_torch_available
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():
@@ -448,9 +448,14 @@ class ReformerTesterMixin:
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_reformer_model_fp16_generate(*config_and_inputs)
@require_multigpu
def test_multigpu_data_parallel_forward(self):
# Opt-out of this test.
pass
@require_torch
class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest.TestCase):
class ReformerLocalAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False
@@ -504,7 +509,7 @@ class ReformerLocalAttnModelTest(ModelTesterMixin, ReformerTesterMixin, unittest
@require_torch
class ReformerLSHAttnModelTest(ModelTesterMixin, unittest.TestCase, ReformerTesterMixin):
class ReformerLSHAttnModelTest(ReformerTesterMixin, ModelTesterMixin, unittest.TestCase):
all_model_classes = (ReformerModel, ReformerModelWithLMHead) if is_torch_available() else ()
all_generative_model_classes = (ReformerModelWithLMHead,) if is_torch_available() else ()
test_pruning = False