Move tests/utils.py -> transformers/testing_utils.py (#5350)

This commit is contained in:
Sam Shleifer
2020-07-01 10:31:17 -04:00
committed by GitHub
parent 9c219305f5
commit 13deb95a40
66 changed files with 66 additions and 86 deletions

View File

@@ -1,8 +1,7 @@
import unittest
from transformers import is_torch_available
from .utils import require_torch
from transformers.testing_utils import require_torch
if is_torch_available():

View File

@@ -4,8 +4,7 @@ import unittest
from pathlib import Path
from transformers import AutoConfig, is_torch_available
from .utils import require_torch, torch_device
from transformers.testing_utils import require_torch, torch_device
if is_torch_available():

View File

@@ -4,8 +4,7 @@ import unittest
from pathlib import Path
from transformers import AutoConfig, is_tf_available
from .utils import require_tf
from transformers.testing_utils import require_tf
if is_tf_available():

View File

@@ -19,8 +19,7 @@ import unittest
from transformers.configuration_auto import CONFIG_MAPPING, AutoConfig
from transformers.configuration_bert import BertConfig
from transformers.configuration_roberta import RobertaConfig
from .utils import DUMMY_UNKWOWN_IDENTIFIER
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER
SAMPLE_ROBERTA_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/dummy-config.json")

View File

@@ -21,8 +21,7 @@ from pathlib import Path
from typing import List, Union
import transformers
from .utils import require_tf, require_torch, slow
from transformers.testing_utils import require_tf, require_torch, slow
logger = logging.getLogger()

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,8 +17,7 @@
import unittest
from transformers import is_torch_available
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_torch, slow
if is_torch_available():

View File

@@ -20,10 +20,10 @@ import timeout_decorator # noqa
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -16,8 +16,7 @@
import unittest
from transformers import is_torch_available
from .utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -21,8 +21,7 @@ import unittest
from typing import List
from transformers import is_torch_available
from .utils import require_multigpu, require_torch, slow, torch_device
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():

View File

@@ -16,10 +16,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -18,12 +18,12 @@ import tempfile
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
from .test_modeling_bert import BertModelTester
from .test_modeling_common import ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -19,8 +19,7 @@ import unittest
from transformers import is_torch_available
from transformers.file_utils import cached_property
from transformers.hf_api import HfApi
from .utils import require_torch, slow, torch_device
from transformers.testing_utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -16,10 +16,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import AlbertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,8 +17,7 @@
import unittest
from transformers import is_tf_available
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import BertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -16,8 +16,7 @@
import unittest
from transformers import is_tf_available
from .utils import require_tf, slow
from transformers.testing_utils import require_tf, slow
if is_tf_available():

View File

@@ -23,8 +23,7 @@ import unittest
from importlib import import_module
from transformers import is_tf_available, is_torch_available
from .utils import _tf_gpu_memory_limit, require_tf
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import CTRLConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import DistilBertConfig, is_tf_available
from transformers.testing_utils import require_tf
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import ElectraConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -16,8 +16,7 @@
import unittest
from transformers import is_tf_available
from .utils import require_tf, slow
from transformers.testing_utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import GPT2Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import MobileBertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import OpenAIGPTConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import RobertaConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import T5Config, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -18,10 +18,10 @@ import random
import unittest
from transformers import TransfoXLConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -16,8 +16,7 @@
import unittest
from transformers import is_tf_available
from .utils import require_tf, slow
from transformers.testing_utils import require_tf, slow
if is_tf_available():

View File

@@ -18,10 +18,10 @@ import random
import unittest
from transformers import XLNetConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
from .utils import require_tf, slow
if is_tf_available():

View File

@@ -17,10 +17,10 @@ import random
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_multigpu, require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,10 +17,10 @@
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -17,8 +17,7 @@
import unittest
from transformers import is_torch_available
from .utils import slow
from transformers.testing_utils import slow
if is_torch_available():

View File

@@ -18,10 +18,10 @@ import random
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device
from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor
from .utils import require_torch, slow, torch_device
if is_torch_available():

View File

@@ -3,9 +3,9 @@ from os.path import dirname, exists
from shutil import rmtree
from tempfile import NamedTemporaryFile, TemporaryDirectory
from tests.utils import require_tf, require_torch, slow
from transformers import BertConfig, BertTokenizerFast, FeatureExtractionPipeline
from transformers.convert_graph_to_onnx import convert, ensure_valid_input, infer_shapes
from transformers.testing_utils import require_tf, require_torch, slow
class FuncContiguousArgs:

View File

@@ -19,8 +19,7 @@ import tempfile
import unittest
from transformers import is_torch_available
from .utils import require_torch
from transformers.testing_utils import require_torch
if is_torch_available():

View File

@@ -1,8 +1,7 @@
import unittest
from transformers import is_tf_available
from .utils import require_tf
from transformers.testing_utils import require_tf
if is_tf_available():

View File

@@ -3,8 +3,7 @@ from typing import Iterable, List, Optional
from transformers import pipeline
from transformers.pipelines import SUPPORTED_TASKS, DefaultArgumentHandler, Pipeline
from .utils import require_tf, require_torch, slow, torch_device
from transformers.testing_utils import require_tf, require_torch, slow, torch_device
DEFAULT_DEVICE_NUM = -1 if torch_device == "cpu" else 0

View File

@@ -27,10 +27,9 @@ from transformers import (
RobertaTokenizer,
RobertaTokenizerFast,
)
from transformers.testing_utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER # noqa: F401
from transformers.tokenization_auto import TOKENIZER_MAPPING
from .utils import DUMMY_UNKWOWN_IDENTIFIER, SMALL_MODEL_IDENTIFIER, slow # noqa: F401
class AutoTokenizerTest(unittest.TestCase):
# @slow

View File

@@ -17,6 +17,7 @@
import os
import unittest
from transformers.testing_utils import slow
from transformers.tokenization_bert import (
VOCAB_FILES_NAMES,
BasicTokenizer,
@@ -29,7 +30,6 @@ from transformers.tokenization_bert import (
)
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

View File

@@ -17,6 +17,7 @@
import os
import unittest
from transformers.testing_utils import custom_tokenizers
from transformers.tokenization_bert import WordpieceTokenizer
from transformers.tokenization_bert_japanese import (
VOCAB_FILES_NAMES,
@@ -26,7 +27,6 @@ from transformers.tokenization_bert_japanese import (
)
from .test_tokenization_common import TokenizerTesterMixin
from .utils import custom_tokenizers
@custom_tokenizers

View File

@@ -22,8 +22,8 @@ import tempfile
from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, List, Tuple, Union
from tests.utils import require_tf, require_torch, slow
from transformers import PreTrainedTokenizer, PreTrainedTokenizerBase, PreTrainedTokenizerFast
from transformers.testing_utils import require_tf, require_torch, slow
if TYPE_CHECKING:

View File

@@ -14,10 +14,10 @@
# limitations under the License.
from transformers.testing_utils import slow
from transformers.tokenization_distilbert import DistilBertTokenizer, DistilBertTokenizerFast
from .test_tokenization_bert import BertTokenizationTest
from .utils import slow
class DistilBertTokenizationTest(BertTokenizationTest):

View File

@@ -3,7 +3,6 @@ import unittest
from collections import namedtuple
from itertools import takewhile
from tests.utils import require_torch
from transformers import (
BertTokenizer,
BertTokenizerFast,
@@ -16,6 +15,7 @@ from transformers import (
TransfoXLTokenizer,
is_torch_available,
)
from transformers.testing_utils import require_torch
from transformers.tokenization_distilbert import DistilBertTokenizerFast
from transformers.tokenization_openai import OpenAIGPTTokenizerFast
from transformers.tokenization_roberta import RobertaTokenizerFast

View File

@@ -18,10 +18,10 @@ import json
import os
import unittest
from transformers.testing_utils import slow
from transformers.tokenization_roberta import VOCAB_FILES_NAMES, AddedToken, RobertaTokenizer, RobertaTokenizerFast
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

View File

@@ -18,9 +18,9 @@ import os
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_torch
from .test_tokenization_common import TokenizerTesterMixin
from .utils import require_torch
if is_torch_available():

View File

@@ -17,10 +17,9 @@ import unittest
from typing import Callable, Optional
from transformers import BatchEncoding, BertTokenizer, BertTokenizerFast, PreTrainedTokenizer, TensorType
from transformers.testing_utils import require_tf, require_torch, slow
from transformers.tokenization_gpt2 import GPT2Tokenizer
from .utils import require_tf, require_torch, slow
class TokenizerUtilsTest(unittest.TestCase):
def check_tokenizer_from_pretrained(self, tokenizer_class):

View File

@@ -18,10 +18,10 @@ import json
import os
import unittest
from transformers.testing_utils import slow
from transformers.tokenization_xlm import VOCAB_FILES_NAMES, XLMTokenizer
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

View File

@@ -18,10 +18,10 @@ import os
import unittest
from transformers.file_utils import cached_property
from transformers.testing_utils import slow
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

View File

@@ -17,10 +17,10 @@
import os
import unittest
from transformers.testing_utils import slow
from transformers.tokenization_xlnet import SPIECE_UNDERLINE, XLNetTokenizer
from .test_tokenization_common import TokenizerTesterMixin
from .utils import slow
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")

View File

@@ -1,8 +1,7 @@
import unittest
from transformers import AutoTokenizer, TrainingArguments, is_torch_available
from .utils import require_torch
from transformers.testing_utils import require_torch
if is_torch_available():

View File

@@ -1,120 +0,0 @@
import os
import unittest
from distutils.util import strtobool
from transformers.file_utils import _tf_available, _torch_available
SMALL_MODEL_IDENTIFIER = "julien-c/bert-xsmall-dummy"
DUMMY_UNKWOWN_IDENTIFIER = "julien-c/dummy-unknown"
# Used to test Auto{Config, Model, Tokenizer} model_type detection.
def parse_flag_from_env(key, default=False):
try:
value = os.environ[key]
except KeyError:
# KEY isn't set, default to `default`.
_value = default
else:
# KEY is set, convert it to True or False.
try:
_value = strtobool(value)
except ValueError:
# More values are supported, but let's keep the message simple.
raise ValueError("If set, {} must be yes or no.".format(key))
return _value
def parse_int_from_env(key, default=None):
try:
value = os.environ[key]
except KeyError:
_value = default
else:
try:
_value = int(value)
except ValueError:
raise ValueError("If set, {} must be a int.".format(key))
return _value
_run_slow_tests = parse_flag_from_env("RUN_SLOW", default=False)
_run_custom_tokenizers = parse_flag_from_env("RUN_CUSTOM_TOKENIZERS", default=False)
_tf_gpu_memory_limit = parse_int_from_env("TF_GPU_MEMORY_LIMIT", default=None)
def slow(test_case):
"""
Decorator marking a test as slow.
Slow tests are skipped by default. Set the RUN_SLOW environment variable
to a truthy value to run them.
"""
if not _run_slow_tests:
test_case = unittest.skip("test is slow")(test_case)
return test_case
def custom_tokenizers(test_case):
"""
Decorator marking a test for a custom tokenizer.
Custom tokenizers require additional dependencies, and are skipped
by default. Set the RUN_CUSTOM_TOKENIZERS environment variable
to a truthy value to run them.
"""
if not _run_custom_tokenizers:
test_case = unittest.skip("test of custom tokenizers")(test_case)
return test_case
def require_torch(test_case):
"""
Decorator marking a test that requires PyTorch.
These tests are skipped when PyTorch isn't installed.
"""
if not _torch_available:
test_case = unittest.skip("test requires PyTorch")(test_case)
return test_case
def require_tf(test_case):
"""
Decorator marking a test that requires TensorFlow.
These tests are skipped when TensorFlow isn't installed.
"""
if not _tf_available:
test_case = unittest.skip("test requires TensorFlow")(test_case)
return test_case
def require_multigpu(test_case):
"""
Decorator marking a test that requires a multi-GPU setup (in PyTorch).
These tests are skipped on a machine without multiple GPUs.
To run *only* the multigpu tests, assuming all test names contain multigpu:
$ pytest -sv ./tests -k "multigpu"
"""
if not _torch_available:
return unittest.skip("test requires PyTorch")(test_case)
import torch
if torch.cuda.device_count() < 2:
return unittest.skip("test requires multiple GPUs")(test_case)
return test_case
if _torch_available:
# Set the USE_CUDA environment variable to select a GPU.
torch_device = "cuda" if parse_flag_from_env("USE_CUDA") else "cpu"
else:
torch_device = None