Improving pipeline tests (#12784)
* Proposal * Testing pipelines slightly better. - Overall same design - Metaclass to get proper different tests instead of subTest (not well supported by Pytest) - Added ANY meta object to make output checking more readable. - Skipping architectures either without tiny_config or without architecture. * Small fix. * Fixing the tests in case of None value. * Oups. * Rebased with more architectures. * Fixing reformer tests (no override anymore). * Adding more options for model tester config. Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -90,7 +90,7 @@ class MBartModelTester:
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1,
|
||||
attention_probs_dropout_prob=0.1,
|
||||
max_position_embeddings=20,
|
||||
max_position_embeddings=100,
|
||||
eos_token_id=2,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
|
||||
@@ -186,6 +186,11 @@ class ReformerModelTester:
|
||||
hash_seed=self.hash_seed,
|
||||
)
|
||||
|
||||
def get_pipeline_config(self):
|
||||
config = self.get_config()
|
||||
config.vocab_size = 100
|
||||
return config
|
||||
|
||||
def create_and_check_reformer_model(self, config, input_ids, input_mask, choice_labels):
|
||||
model = ReformerModel(config=config)
|
||||
model.to(torch_device)
|
||||
|
||||
@@ -12,15 +12,126 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
import logging
|
||||
import string
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from unittest import mock
|
||||
from unittest import mock, skipIf
|
||||
|
||||
from transformers import is_tf_available, is_torch_available, pipeline
|
||||
from transformers import TOKENIZER_MAPPING, AutoTokenizer, is_tf_available, is_torch_available, pipeline
|
||||
from transformers.file_utils import to_py_obj
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_checkpoint_from_architecture(architecture):
|
||||
module = importlib.import_module(architecture.__module__)
|
||||
|
||||
if hasattr(module, "_CHECKPOINT_FOR_DOC"):
|
||||
return module._CHECKPOINT_FOR_DOC
|
||||
else:
|
||||
logger.warning(f"Can't retrieve checkpoint from {architecture.__name__}")
|
||||
|
||||
|
||||
def get_tiny_config_from_class(configuration_class):
|
||||
if "OpenAIGPT" in configuration_class.__name__:
|
||||
# This is the only file that is inconsistent with the naming scheme.
|
||||
# Will rename this file if we decide this is the way to go
|
||||
return
|
||||
|
||||
model_type = configuration_class.model_type
|
||||
camel_case_model_name = configuration_class.__name__.split("Config")[0]
|
||||
|
||||
module = importlib.import_module(f".test_modeling_{model_type.replace('-', '_')}", package="tests")
|
||||
model_tester_class = getattr(module, f"{camel_case_model_name}ModelTester", None)
|
||||
|
||||
if model_tester_class is None:
|
||||
logger.warning(f"No model tester class for {configuration_class.__name__}")
|
||||
return
|
||||
|
||||
model_tester = model_tester_class(parent=None)
|
||||
|
||||
if hasattr(model_tester, "get_pipeline_config"):
|
||||
return model_tester.get_pipeline_config()
|
||||
elif hasattr(model_tester, "get_config"):
|
||||
return model_tester.get_config()
|
||||
else:
|
||||
logger.warning(f"Model tester {model_tester_class.__name__} has no `get_config()`.")
|
||||
|
||||
|
||||
@lru_cache(maxsize=100)
|
||||
def get_tiny_tokenizer_from_checkpoint(checkpoint):
|
||||
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
|
||||
logger.warning("Training new from iterator ...")
|
||||
vocabulary = string.ascii_letters + string.digits + " "
|
||||
tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
|
||||
logger.warning("Trained.")
|
||||
return tokenizer
|
||||
|
||||
|
||||
class ANY:
|
||||
def __init__(self, _type):
|
||||
self._type = _type
|
||||
|
||||
def __eq__(self, other):
|
||||
return isinstance(other, self._type)
|
||||
|
||||
def __repr__(self):
|
||||
return f"ANY({self._type.__name__})"
|
||||
|
||||
|
||||
class PipelineTestCaseMeta(type):
|
||||
def __new__(mcs, name, bases, dct):
|
||||
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class):
|
||||
@skipIf(tiny_config is None, "TinyConfig does not exist")
|
||||
@skipIf(checkpoint is None, "checkpoint does not exist")
|
||||
def test(self):
|
||||
model = ModelClass(tiny_config)
|
||||
if hasattr(model, "eval"):
|
||||
model = model.eval()
|
||||
try:
|
||||
tokenizer = get_tiny_tokenizer_from_checkpoint(checkpoint)
|
||||
tokenizer.model_max_length = model.config.max_position_embeddings
|
||||
# Rust Panic exception are NOT Exception subclass
|
||||
# Some test tokenizer contain broken vocabs or custom PreTokenizer, so we
|
||||
# provide some default tokenizer and hope for the best.
|
||||
except: # noqa: E722
|
||||
logger.warning(f"Tokenizer cannot be created from checkpoint {checkpoint}")
|
||||
tokenizer = get_tiny_tokenizer_from_checkpoint("gpt2")
|
||||
tokenizer.model_max_length = model.config.max_position_embeddings
|
||||
self.run_pipeline_test(model, tokenizer)
|
||||
|
||||
return test
|
||||
|
||||
mapping = dct.get("model_mapping", {})
|
||||
if mapping:
|
||||
for configuration, model_architecture in mapping.items():
|
||||
checkpoint = get_checkpoint_from_architecture(model_architecture)
|
||||
tiny_config = get_tiny_config_from_class(configuration)
|
||||
tokenizer_classes = TOKENIZER_MAPPING.get(configuration, [])
|
||||
for tokenizer_class in tokenizer_classes:
|
||||
if tokenizer_class is not None and tokenizer_class.__name__.endswith("Fast"):
|
||||
test_name = f"test_pt_{configuration.__name__}_{model_architecture.__name__}_{tokenizer_class.__name__}"
|
||||
dct[test_name] = gen_test(model_architecture, checkpoint, tiny_config, tokenizer_class)
|
||||
|
||||
tf_mapping = dct.get("tf_model_mapping", {})
|
||||
if tf_mapping:
|
||||
for configuration, model_architecture in tf_mapping.items():
|
||||
checkpoint = get_checkpoint_from_architecture(model_architecture)
|
||||
tiny_config = get_tiny_config_from_class(configuration)
|
||||
tokenizer_classes = TOKENIZER_MAPPING.get(configuration, [])
|
||||
for tokenizer_class in tokenizer_classes:
|
||||
if tokenizer_class is not None and tokenizer_class.__name__.endswith("Fast"):
|
||||
test_name = f"test_tf_{configuration.__name__}_{model_architecture.__name__}_{tokenizer_class.__name__}"
|
||||
dct[test_name] = gen_test(model_architecture, checkpoint, tiny_config, tokenizer_class)
|
||||
|
||||
return type.__new__(mcs, name, bases, dct)
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
|
||||
|
||||
@@ -14,13 +14,61 @@
|
||||
|
||||
import unittest
|
||||
|
||||
from .test_pipelines_common import MonoInputPipelineCommonMixin
|
||||
from transformers import (
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
TextClassificationPipeline,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch, slow
|
||||
|
||||
from .test_pipelines_common import ANY, PipelineTestCaseMeta
|
||||
|
||||
|
||||
class TextClassificationPipelineTests(MonoInputPipelineCommonMixin, unittest.TestCase):
|
||||
pipeline_task = "sentiment-analysis"
|
||||
small_models = [
|
||||
"sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english"
|
||||
] # Default model - Models tested without the @slow decorator
|
||||
large_models = [None] # Models tested with the @slow decorator
|
||||
mandatory_keys = {"label", "score"} # Keys which should be in the output
|
||||
@is_pipeline_test
|
||||
class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
|
||||
model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
def test_pt_bert(self):
|
||||
text_classifier = pipeline("text-classification")
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 1.0}])
|
||||
outputs = text_classifier("This is bad !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "NEGATIVE", "score": 1.0}])
|
||||
outputs = text_classifier("Birds are a type of animal")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
|
||||
|
||||
@slow
|
||||
@require_tf
|
||||
def test_tf_bert(self):
|
||||
text_classifier = pipeline("text-classification", framework="tf")
|
||||
|
||||
outputs = text_classifier("This is great !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 1.0}])
|
||||
outputs = text_classifier("This is bad !")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "NEGATIVE", "score": 1.0}])
|
||||
outputs = text_classifier("Birds are a type of animal")
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
|
||||
|
||||
def run_pipeline_test(self, model, tokenizer):
|
||||
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
|
||||
|
||||
# Small inputs because BartTokenizer tiny has maximum position embeddings = 22
|
||||
valid_inputs = "HuggingFace is in"
|
||||
outputs = text_classifier(valid_inputs)
|
||||
|
||||
self.assertEqual(nested_simplify(outputs), [{"label": ANY(str), "score": ANY(float)}])
|
||||
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
|
||||
|
||||
valid_inputs = ["HuggingFace is in ", "Paris is in France"]
|
||||
outputs = text_classifier(valid_inputs)
|
||||
self.assertEqual(
|
||||
nested_simplify(outputs),
|
||||
[{"label": ANY(str), "score": ANY(float)}, {"label": ANY(str), "score": ANY(float)}],
|
||||
)
|
||||
self.assertTrue(outputs[0]["label"] in model.config.id2label.values())
|
||||
self.assertTrue(outputs[1]["label"] in model.config.id2label.values())
|
||||
|
||||
Reference in New Issue
Block a user