Test XLA examples (#5583)
* Test XLA examples * Style * Using `require_torch_tpu` * Style * No need for pytest
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
import unittest
|
||||
from distutils.util import strtobool
|
||||
|
||||
from transformers.file_utils import _tf_available, _torch_available
|
||||
from transformers.file_utils import _tf_available, _torch_available, _torch_tpu_available
|
||||
|
||||
|
||||
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
|
||||
@@ -113,6 +113,16 @@ def require_multigpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_tpu(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires a TPU (in PyTorch).
|
||||
"""
|
||||
if not _torch_tpu_available:
|
||||
return unittest.skip("test requires PyTorch TPU")
|
||||
|
||||
return test_case
|
||||
|
||||
|
||||
if _torch_available:
|
||||
# Set the USE_CUDA environment variable to select a GPU.
|
||||
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
|
||||
|
||||
Reference in New Issue
Block a user