Moving zero-shot-classification pipeline to new testing. (#13299)
* Moving `zero-shot-classification` pipeline to new testing. * Cleaning up old mixins. * Fixing tests `sshleifer/tiny-distilbert-base-uncased-finetuned-sst-2-english` is corrupted in PT. * Adding warning.
This commit is contained in:
@@ -17,21 +17,9 @@ import logging
|
||||
import string
|
||||
from abc import abstractmethod
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from unittest import mock, skipIf
|
||||
from unittest import skipIf
|
||||
|
||||
from transformers import (
|
||||
FEATURE_EXTRACTOR_MAPPING,
|
||||
TOKENIZER_MAPPING,
|
||||
AutoFeatureExtractor,
|
||||
AutoTokenizer,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
pipeline,
|
||||
)
|
||||
from transformers.file_utils import to_py_obj
|
||||
from transformers.pipelines import Pipeline
|
||||
from transformers.testing_utils import _run_slow_tests, is_pipeline_test, require_tf, require_torch, slow
|
||||
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -189,228 +177,3 @@ class PipelineTestCaseMeta(type):
|
||||
dct["test_small_model_tf"] = dct.get("test_small_model_tf", inner)
|
||||
|
||||
return type.__new__(mcs, name, bases, dct)
|
||||
|
||||
|
||||
VALID_INPUTS = ["A simple string", ["list of strings"]]
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class CustomInputPipelineCommonMixin:
|
||||
pipeline_task = None
|
||||
pipeline_loading_kwargs = {} # Additional kwargs to load the pipeline with
|
||||
pipeline_running_kwargs = {} # Additional kwargs to run the pipeline with
|
||||
small_models = [] # Models tested without the @slow decorator
|
||||
large_models = [] # Models tested with the @slow decorator
|
||||
valid_inputs = VALID_INPUTS # Some inputs which are valid to compare fast and slow tokenizers
|
||||
|
||||
def setUp(self) -> None:
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
return # Currently no JAX pipelines
|
||||
|
||||
# Download needed checkpoints
|
||||
models = self.small_models
|
||||
if _run_slow_tests:
|
||||
models = models + self.large_models
|
||||
|
||||
for model_name in models:
|
||||
if is_torch_available():
|
||||
pipeline(
|
||||
self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
if is_tf_available():
|
||||
pipeline(
|
||||
self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_pt_defaults(self):
|
||||
pipeline(self.pipeline_task, framework="pt", **self.pipeline_loading_kwargs)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_defaults(self):
|
||||
pipeline(self.pipeline_task, framework="tf", **self.pipeline_loading_kwargs)
|
||||
|
||||
@require_torch
|
||||
def test_torch_small(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_small = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_small)
|
||||
|
||||
@require_tf
|
||||
def test_tf_small(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_small = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_small)
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
def test_torch_large(self):
|
||||
for model_name in self.large_models:
|
||||
pipe_large = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_large)
|
||||
|
||||
@require_tf
|
||||
@slow
|
||||
def test_tf_large(self):
|
||||
for model_name in self.large_models:
|
||||
pipe_large = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._test_pipeline(pipe_large)
|
||||
|
||||
def _test_pipeline(self, pipe: Pipeline):
|
||||
raise NotImplementedError
|
||||
|
||||
@require_torch
|
||||
def test_compare_slow_fast_torch(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_slow = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
use_fast=False,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
pipe_fast = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="pt",
|
||||
use_fast=True,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="forward")
|
||||
|
||||
@require_tf
|
||||
def test_compare_slow_fast_tf(self):
|
||||
for model_name in self.small_models:
|
||||
pipe_slow = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
use_fast=False,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
pipe_fast = pipeline(
|
||||
task=self.pipeline_task,
|
||||
model=model_name,
|
||||
tokenizer=model_name,
|
||||
framework="tf",
|
||||
use_fast=True,
|
||||
**self.pipeline_loading_kwargs,
|
||||
)
|
||||
self._compare_slow_fast_pipelines(pipe_slow, pipe_fast, method="call")
|
||||
|
||||
def _compare_slow_fast_pipelines(self, pipe_slow: Pipeline, pipe_fast: Pipeline, method: str):
|
||||
"""We check that the inputs to the models forward passes are identical for
|
||||
slow and fast tokenizers.
|
||||
"""
|
||||
with mock.patch.object(
|
||||
pipe_slow.model, method, wraps=getattr(pipe_slow.model, method)
|
||||
) as mock_slow, mock.patch.object(
|
||||
pipe_fast.model, method, wraps=getattr(pipe_fast.model, method)
|
||||
) as mock_fast:
|
||||
for inputs in self.valid_inputs:
|
||||
if isinstance(inputs, dict):
|
||||
inputs.update(self.pipeline_running_kwargs)
|
||||
_ = pipe_slow(**inputs)
|
||||
_ = pipe_fast(**inputs)
|
||||
else:
|
||||
_ = pipe_slow(inputs, **self.pipeline_running_kwargs)
|
||||
_ = pipe_fast(inputs, **self.pipeline_running_kwargs)
|
||||
|
||||
mock_slow.assert_called()
|
||||
mock_fast.assert_called()
|
||||
|
||||
self.assertEqual(len(mock_slow.call_args_list), len(mock_fast.call_args_list))
|
||||
for mock_slow_call_args, mock_fast_call_args in zip(
|
||||
mock_slow.call_args_list, mock_slow.call_args_list
|
||||
):
|
||||
slow_call_args, slow_call_kwargs = mock_slow_call_args
|
||||
fast_call_args, fast_call_kwargs = mock_fast_call_args
|
||||
|
||||
slow_call_args, slow_call_kwargs = to_py_obj(slow_call_args), to_py_obj(slow_call_kwargs)
|
||||
fast_call_args, fast_call_kwargs = to_py_obj(fast_call_args), to_py_obj(fast_call_kwargs)
|
||||
|
||||
self.assertEqual(slow_call_args, fast_call_args)
|
||||
self.assertDictEqual(slow_call_kwargs, fast_call_kwargs)
|
||||
|
||||
|
||||
@is_pipeline_test
|
||||
class MonoInputPipelineCommonMixin(CustomInputPipelineCommonMixin):
|
||||
"""A version of the CustomInputPipelineCommonMixin
|
||||
with a predefined `_test_pipeline` method.
|
||||
"""
|
||||
|
||||
mandatory_keys = {} # Keys which should be in the output
|
||||
invalid_inputs = [None] # inputs which are not allowed
|
||||
expected_multi_result: Optional[List] = None
|
||||
expected_check_keys: Optional[List[str]] = None
|
||||
|
||||
def _test_pipeline(self, pipe: Pipeline):
|
||||
self.assertIsNotNone(pipe)
|
||||
|
||||
mono_result = pipe(self.valid_inputs[0], **self.pipeline_running_kwargs)
|
||||
self.assertIsInstance(mono_result, list)
|
||||
self.assertIsInstance(mono_result[0], (dict, list))
|
||||
|
||||
if isinstance(mono_result[0], list):
|
||||
mono_result = mono_result[0]
|
||||
|
||||
for key in self.mandatory_keys:
|
||||
self.assertIn(key, mono_result[0])
|
||||
|
||||
multi_result = [pipe(input, **self.pipeline_running_kwargs) for input in self.valid_inputs]
|
||||
self.assertIsInstance(multi_result, list)
|
||||
self.assertIsInstance(multi_result[0], (dict, list))
|
||||
|
||||
if self.expected_multi_result is not None:
|
||||
for result, expect in zip(multi_result, self.expected_multi_result):
|
||||
for key in self.expected_check_keys or []:
|
||||
self.assertEqual(
|
||||
set([o[key] for o in result]),
|
||||
set([o[key] for o in expect]),
|
||||
)
|
||||
|
||||
if isinstance(multi_result[0], list):
|
||||
multi_result = multi_result[0]
|
||||
|
||||
for result in multi_result:
|
||||
for key in self.mandatory_keys:
|
||||
self.assertIn(key, result)
|
||||
|
||||
self.assertRaises(Exception, pipe, self.invalid_inputs)
|
||||
|
||||
Reference in New Issue
Block a user