Test XLA examples (#5583)

* Test XLA examples

* Style

* Using `require_torch_tpu`

* Style

* No need for pytest
This commit is contained in:
Lysandre Debut
2020-07-09 09:19:19 -04:00
committed by GitHub
parent 3bd55199cd
commit 0533cf4706
2 changed files with 102 additions and 1 deletions

View File

@@ -2,7 +2,7 @@ import os
import unittest
from distutils.util import strtobool
from transformers.file_utils import _tf_available, _torch_available
from transformers.file_utils import _tf_available, _torch_available, _torch_tpu_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
@@ -113,6 +113,16 @@ def require_multigpu(test_case):
return test_case
def require_torch_tpu(test_case):
"""
Decorator marking a test that requires a TPU (in PyTorch).
"""
if not _torch_tpu_available:
return unittest.skip("test requires PyTorch TPU")
return test_case
if _torch_available:
# Set the USE_CUDA environment variable to select a GPU.
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"