Fix nn.DataParallel compatibility in PyTorch 1.5 (#4300)
* Test case for #3936 * multigpu tests pass on pytorch 1.4.0 * Fixup * multigpu tests pass on pytorch 1.5.0 * Update src/transformers/modeling_utils.py * Update src/transformers/modeling_utils.py * rename multigpu to require_multigpu * mode doc
This commit is contained in:
@@ -94,6 +94,25 @@ def require_tf(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_multigpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
|
||||
|
||||
These tests are skipped on a machine without multiple GPUs.
|
||||
|
||||
To run *only* the multigpu tests, assuming all test names contain multigpu:
|
||||
$ pytest -sv ./tests -k "multigpu"
|
||||
"""
|
||||
if not _torch_available:
|
||||
return unittest.skip("test requires PyTorch")(test_case)
|
||||
|
||||
import torch
|
||||
|
||||
if torch.cuda.device_count() < 2:
|
||||
return unittest.skip("test requires multiple GPUs")(test_case)
|
||||
return test_case
|
||||
|
||||
|
||||
if _torch_available:
|
||||
# Set the USE_CUDA environment variable to select a GPU.
|
||||
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
||||
|
||||
Reference in New Issue
Block a user