Adding batch_size support for (almost) all pipelines (#13724)

* Tentative enabling of `batch_size` for pipelines.

* Add systematic test for pipeline batching.

* Enabling batch_size on almost all pipelines

- Not `zero-shot` (it's already passing stuff as batched so trickier)
- Not `QA` (preprocess uses squad features, we need to switch to real
tensors at this boundary.

* Adding `min_length_for_response` for conversational.

* Making CTC, speech mappings avaiable regardless of framework.

* Attempt at fixing automatic tests (ffmpeg not enabled for fast tests)

* Removing ffmpeg dependency in tests.

* Small fixes.

* Slight cleanup.

* Adding docs

and adressing comments.

* Quality.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/zero_shot_classification.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Improving docs.

* Update docs/source/main_classes/pipelines.rst

Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>

* N -> oberved_batch_size

softmax trick.

* Follow `padding_side`.

* Supporting image pipeline batching (and padding).

* Rename `unbatch` -> `loader_batch`.

* unbatch_size forgot.

* Custom padding for offset mappings.

* Attempt to remove librosa.

* Adding require_audio.

* torchaudio.

* Back to using datasets librosa.

* Adding help to set a pad_token on the tokenizer.

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Quality.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Philipp Schmid <32632186+philschmid@users.noreply.github.com>
This commit is contained in:
Nicolas Patry
2021-10-29 11:34:18 +02:00
committed by GitHub
parent 4469010c1b
commit be236361f1
27 changed files with 629 additions and 64 deletions

View File

@@ -24,6 +24,7 @@ from transformers.testing_utils import (
require_datasets,
require_tf,
require_torch,
require_torchaudio,
slow,
)
@@ -35,15 +36,16 @@ from .test_pipelines_common import ANY, PipelineTestCaseMeta
class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING
@require_datasets
@slow
def run_pipeline_test(self, model, tokenizer, feature_extractor):
import datasets
def get_test_pipeline(self, model, tokenizer, feature_extractor):
audio_classifier = AudioClassificationPipeline(model=model, feature_extractor=feature_extractor)
# test with a raw waveform
audio = np.zeros((34000,))
audio2 = np.zeros((14000,))
return audio_classifier, [audio2, audio]
def run_pipeline_test(self, audio_classifier, examples):
audio2, audio = examples
output = audio_classifier(audio)
# by default a model is initialized with num_labels=2
self.assertEqual(
@@ -61,10 +63,17 @@ class AudioClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
],
)
self.run_torchaudio(audio_classifier)
@require_datasets
@require_torchaudio
def run_torchaudio(self, audio_classifier):
import datasets
# test with a local file
dataset = datasets.load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = dataset[0]["file"]
output = audio_classifier(filename)
audio = dataset[0]["audio"]["array"]
output = audio_classifier(audio)
self.assertEqual(
output,
[

View File

@@ -14,11 +14,28 @@
import unittest
import numpy as np
import pytest
from transformers import AutoFeatureExtractor, AutoTokenizer, Speech2TextForConditionalGeneration, Wav2Vec2ForCTC
from transformers import (
MODEL_FOR_CTC_MAPPING,
MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
AutoFeatureExtractor,
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.testing_utils import is_pipeline_test, require_datasets, require_torch, require_torchaudio, slow
from transformers.testing_utils import (
is_pipeline_test,
require_datasets,
require_tf,
require_torch,
require_torchaudio,
slow,
)
from .test_pipelines_common import ANY, PipelineTestCaseMeta
# We can't use this mixin because it assumes TF support.
@@ -26,14 +43,42 @@ from transformers.testing_utils import is_pipeline_test, require_datasets, requi
@is_pipeline_test
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = {
k: v
for k, v in (list(MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING.items()) if MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING else [])
+ (MODEL_FOR_CTC_MAPPING.items() if MODEL_FOR_CTC_MAPPING else [])
}
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if tokenizer is None:
# Side effect of no Fast Tokenizer class for these model, so skipping
# But the slow tokenizer test should still run as they're quite small
self.skipTest("No tokenizer available")
return
# return None, None
speech_recognizer = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
# test with a raw waveform
audio = np.zeros((34000,))
audio2 = np.zeros((14000,))
return speech_recognizer, [audio, audio2]
def run_pipeline_test(self, speech_recognizer, examples):
audio = np.zeros((34000,))
outputs = speech_recognizer(audio)
self.assertEqual(outputs, {"text": ANY(str)})
@require_torch
@slow
def test_pt_defaults(self):
pipeline("automatic-speech-recognition", framework="pt")
@require_torch
def test_torch_small(self):
def test_small_model_pt(self):
import numpy as np
speech_recognizer = pipeline(
@@ -46,6 +91,10 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
output = speech_recognizer(waveform)
self.assertEqual(output, {"text": "(Applaudissements)"})
@require_tf
def test_small_model_tf(self):
self.skipTest("Tensorflow not supported yet.")
@require_torch
def test_torch_small_no_tokenizer_files(self):
# test that model without tokenizer file cannot be loaded

View File

@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import importlib
import logging
import random
import string
import unittest
from abc import abstractmethod
@@ -21,6 +23,7 @@ from functools import lru_cache
from unittest import skipIf
from transformers import FEATURE_EXTRACTOR_MAPPING, TOKENIZER_MAPPING, AutoFeatureExtractor, AutoTokenizer, pipeline
from transformers.pipelines.base import _pad
from transformers.testing_utils import is_pipeline_test, require_torch
@@ -73,6 +76,12 @@ def get_tiny_config_from_class(configuration_class):
@lru_cache(maxsize=100)
def get_tiny_tokenizer_from_checkpoint(checkpoint):
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
if tokenizer.vocab_size < 300:
# Wav2Vec2ForCTC for instance
# ByT5Tokenizer
# all are already small enough and have no Fast version that can
# be retrained
return tokenizer
logger.info("Training new from iterator ...")
vocabulary = string.ascii_letters + string.digits + " "
tokenizer = tokenizer.train_new_from_iterator(vocabulary, vocab_size=len(vocabulary), show_progress=False)
@@ -87,6 +96,12 @@ def get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config):
feature_extractor = None
if hasattr(tiny_config, "image_size") and feature_extractor:
feature_extractor = feature_extractor.__class__(size=tiny_config.image_size, crop_size=tiny_config.image_size)
# Speech2TextModel specific.
if hasattr(tiny_config, "input_feat_per_channel") and feature_extractor:
feature_extractor = feature_extractor.__class__(
feature_size=tiny_config.input_feat_per_channel, num_mel_bins=tiny_config.input_feat_per_channel
)
return feature_extractor
@@ -136,7 +151,26 @@ class PipelineTestCaseMeta(type):
else:
tokenizer = None
feature_extractor = get_tiny_feature_extractor_from_checkpoint(checkpoint, tiny_config)
self.run_pipeline_test(model, tokenizer, feature_extractor)
pipeline, examples = self.get_test_pipeline(model, tokenizer, feature_extractor)
if pipeline is None:
# The test can disable itself, but it should be very marginal
# Concerns: Wav2Vec2ForCTC without tokenizer test (FastTokenizer don't exist)
return
self.run_pipeline_test(pipeline, examples)
def run_batch_test(pipeline, examples):
# Need to copy because `Conversation` are stateful
if pipeline.tokenizer is not None and pipeline.tokenizer.pad_token_id is None:
return # No batching for this and it's OK
# 10 examples with batch size 4 means there needs to be a unfinished batch
# which is important for the unbatcher
dataset = [copy.deepcopy(random.choice(examples)) for i in range(10)]
for item in pipeline(dataset, batch_size=4):
pass
run_batch_test(pipeline, examples)
return test
@@ -211,3 +245,85 @@ class CommonPipelineTest(unittest.TestCase):
dataset = MyDataset()
for output in text_classifier(dataset):
self.assertEqual(output, {"label": ANY(str), "score": ANY(float)})
@is_pipeline_test
class PipelinePadTest(unittest.TestCase):
@require_torch
def test_pipeline_padding(self):
import torch
items = [
{
"label": "label1",
"input_ids": torch.LongTensor([[1, 23, 24, 2]]),
"attention_mask": torch.LongTensor([[0, 1, 1, 0]]),
},
{
"label": "label2",
"input_ids": torch.LongTensor([[1, 23, 24, 43, 44, 2]]),
"attention_mask": torch.LongTensor([[0, 1, 1, 1, 1, 0]]),
},
]
self.assertEqual(_pad(items, "label", 0, "right"), ["label1", "label2"])
self.assertTrue(
torch.allclose(
_pad(items, "input_ids", 10, "right"),
torch.LongTensor([[1, 23, 24, 2, 10, 10], [1, 23, 24, 43, 44, 2]]),
)
)
self.assertTrue(
torch.allclose(
_pad(items, "input_ids", 10, "left"),
torch.LongTensor([[10, 10, 1, 23, 24, 2], [1, 23, 24, 43, 44, 2]]),
)
)
self.assertTrue(
torch.allclose(
_pad(items, "attention_mask", 0, "right"), torch.LongTensor([[0, 1, 1, 0, 0, 0], [0, 1, 1, 1, 1, 0]])
)
)
@require_torch
def test_pipeline_image_padding(self):
import torch
items = [
{
"label": "label1",
"pixel_values": torch.zeros((1, 3, 10, 10)),
},
{
"label": "label2",
"pixel_values": torch.zeros((1, 3, 10, 10)),
},
]
self.assertEqual(_pad(items, "label", 0, "right"), ["label1", "label2"])
self.assertTrue(
torch.allclose(
_pad(items, "pixel_values", 10, "right"),
torch.zeros((2, 3, 10, 10)),
)
)
@require_torch
def test_pipeline_offset_mapping(self):
import torch
items = [
{
"offset_mappings": torch.zeros([1, 11, 2], dtype=torch.long),
},
{
"offset_mappings": torch.zeros([1, 4, 2], dtype=torch.long),
},
]
self.assertTrue(
torch.allclose(
_pad(items, "offset_mappings", 0, "right"),
torch.zeros((2, 11, 2), dtype=torch.long),
),
)

View File

@@ -54,8 +54,11 @@ class ConversationalPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
else []
)
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
conversation_agent = ConversationalPipeline(model=model, tokenizer=tokenizer)
return conversation_agent, [Conversation("Hi there!")]
def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"))
self.assertEqual(outputs, Conversation(past_user_inputs=["Hi there!"], generated_responses=[ANY(str)]))

View File

@@ -14,7 +14,15 @@
import unittest
from transformers import MODEL_MAPPING, TF_MODEL_MAPPING, CLIPConfig, FeatureExtractionPipeline, LxmertConfig, pipeline
from transformers import (
MODEL_MAPPING,
TF_MODEL_MAPPING,
CLIPConfig,
FeatureExtractionPipeline,
LxmertConfig,
Wav2Vec2Config,
pipeline,
)
from transformers.testing_utils import is_pipeline_test, nested_simplify, require_tf, require_torch
from .test_pipelines_common import PipelineTestCaseMeta
@@ -61,12 +69,12 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
raise ValueError("We expect lists of floats, nothing else")
return shape
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if tokenizer is None:
self.skipTest("No tokenizer")
return
elif isinstance(model.config, (LxmertConfig, CLIPConfig)):
elif isinstance(model.config, (LxmertConfig, CLIPConfig, Wav2Vec2Config)):
self.skipTest(
"This is an Lxmert bimodal model, we need to find a more consistent way to switch on those models."
)
@@ -81,11 +89,12 @@ class FeatureExtractionPipelineTests(unittest.TestCase, metaclass=PipelineTestCa
)
return
feature_extractor = FeatureExtractionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
return feature_extractor, ["This is a test", "This is another test"]
def run_pipeline_test(self, feature_extractor, examples):
outputs = feature_extractor("This is a test")
shape = self.get_shape(outputs)

View File

@@ -159,22 +159,32 @@ class FillMaskPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="pt")
unmasker.tokenizer.pad_token_id = None
unmasker.tokenizer.pad_token = None
self.run_pipeline_test(unmasker.model, unmasker.tokenizer, None)
self.run_pipeline_test(unmasker, [])
@require_tf
def test_model_no_pad_tf(self):
unmasker = pipeline(task="fill-mask", model="sshleifer/tiny-distilroberta-base", framework="tf")
unmasker.tokenizer.pad_token_id = None
unmasker.tokenizer.pad_token = None
self.run_pipeline_test(unmasker.model, unmasker.tokenizer, None)
self.run_pipeline_test(unmasker, [])
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if tokenizer is None or tokenizer.mask_token_id is None:
self.skipTest("The provided tokenizer has no mask token, (probably reformer or wav2vec2)")
fill_masker = FillMaskPipeline(model=model, tokenizer=tokenizer)
examples = [
f"This is another {tokenizer.mask_token} test",
]
return fill_masker, examples
outputs = fill_masker(f"This is a {tokenizer.mask_token}")
def run_pipeline_test(self, fill_masker, examples):
tokenizer = fill_masker.tokenizer
model = fill_masker.model
outputs = fill_masker(
f"This is a {tokenizer.mask_token}",
)
self.assertEqual(
outputs,
[

View File

@@ -44,9 +44,17 @@ else:
class ImageClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING
@require_datasets
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
image_classifier = ImageClassificationPipeline(model=model, feature_extractor=feature_extractor)
examples = [
Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png"),
"http://images.cocodataset.org/val2017/000000039769.jpg",
]
return image_classifier, examples
@require_datasets
def run_pipeline_test(self, image_classifier, examples):
outputs = image_classifier("./tests/fixtures/tests_samples/COCO/000000039769.png")
self.assertEqual(

View File

@@ -53,9 +53,12 @@ else:
class ObjectDetectionPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_OBJECT_DETECTION_MAPPING
@require_datasets
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
object_detector = ObjectDetectionPipeline(model=model, feature_extractor=feature_extractor)
return object_detector, ["./tests/fixtures/tests_samples/COCO/000000039769.png"]
@require_datasets
def run_pipeline_test(self, object_detector, examples):
outputs = object_detector("./tests/fixtures/tests_samples/COCO/000000039769.png", threshold=0.0)
self.assertGreater(len(outputs), 0)

View File

@@ -32,13 +32,20 @@ class QAPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta):
model_mapping = MODEL_FOR_QUESTION_ANSWERING_MAPPING
tf_model_mapping = TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if isinstance(model.config, LxmertConfig):
# This is an bimodal model, we need to find a more consistent way
# to switch on those models.
return
return None, None
question_answerer = QuestionAnsweringPipeline(model, tokenizer)
examples = [
{"question": "Where was HuggingFace founded ?", "context": "HuggingFace was founded in Paris."},
{"question": "In what field is HuggingFace ?", "context": "HuggingFace is an AI startup."},
]
return question_answerer, examples
def run_pipeline_test(self, question_answerer, _):
outputs = question_answerer(
question="Where was HuggingFace founded ?", context="HuggingFace was founded in Paris."
)

View File

@@ -36,8 +36,12 @@ class SummarizationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMe
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
summarizer = SummarizationPipeline(model=model, tokenizer=tokenizer)
return summarizer, ["(CNN)The Palestinian Authority officially became", "Some other text"]
def run_pipeline_test(self, summarizer, _):
model = summarizer.model
outputs = summarizer("(CNN)The Palestinian Authority officially became")
self.assertEqual(outputs, [{"summary_text": ANY(str)}])

View File

@@ -30,9 +30,11 @@ class Text2TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTest
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
generator = Text2TextGenerationPipeline(model=model, tokenizer=tokenizer)
return generator, ["Something to write", "Something else"]
def run_pipeline_test(self, generator, _):
outputs = generator("Something there")
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
# These are encoder decoder, they don't just append to incoming string

View File

@@ -72,9 +72,12 @@ class TextClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTestC
outputs = text_classifier("Birds are a type of animal")
self.assertEqual(nested_simplify(outputs), [{"label": "POSITIVE", "score": 0.988}])
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
text_classifier = TextClassificationPipeline(model=model, tokenizer=tokenizer)
return text_classifier, ["HuggingFace is in", "This is another test"]
def run_pipeline_test(self, text_classifier, _):
model = text_classifier.model
# Small inputs because BartTokenizer tiny has maximum position embeddings = 22
valid_inputs = "HuggingFace is in"
outputs = text_classifier(valid_inputs)

View File

@@ -88,8 +88,14 @@ class TextGenerationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseM
],
)
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
return text_generator, ["This is a test", "Another test"]
def run_pipeline_test(self, text_generator, _):
model = text_generator.model
tokenizer = text_generator.tokenizer
outputs = text_generator("This is a test")
self.assertEqual(outputs, [{"generated_text": ANY(str)}])
self.assertTrue(outputs[0]["generated_text"].startswith("This is a test"))

View File

@@ -45,8 +45,13 @@ class TokenClassificationPipelineTests(unittest.TestCase, metaclass=PipelineTest
model_mapping = MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
def get_test_pipeline(self, model, tokenizer, feature_extractor):
token_classifier = TokenClassificationPipeline(model=model, tokenizer=tokenizer)
return token_classifier, ["A simple string", "A simple string that is quite a bit longer"]
def run_pipeline_test(self, token_classifier, _):
model = token_classifier.model
tokenizer = token_classifier.tokenizer
outputs = token_classifier("A simple string")
self.assertIsInstance(outputs, list)

View File

@@ -20,6 +20,7 @@ from transformers import (
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
MBart50TokenizerFast,
MBartConfig,
MBartForConditionalGeneration,
TranslationPipeline,
pipeline,
@@ -34,14 +35,16 @@ class TranslationPipelineTests(unittest.TestCase, metaclass=PipelineTestCaseMeta
model_mapping = MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
translator = TranslationPipeline(model=model, tokenizer=tokenizer)
try:
outputs = translator("Some string")
except ValueError:
# Triggered by m2m langages
src_lang, tgt_lang = list(translator.tokenizer.lang_code_to_id.keys())[:2]
outputs = translator("Some string", src_lang=src_lang, tgt_lang=tgt_lang)
def get_test_pipeline(self, model, tokenizer, feature_extractor):
if isinstance(model.config, MBartConfig):
src_lang, tgt_lang = list(tokenizer.lang_code_to_id.keys())[:2]
translator = TranslationPipeline(model=model, tokenizer=tokenizer, src_lang=src_lang, tgt_lang=tgt_lang)
else:
translator = TranslationPipeline(model=model, tokenizer=tokenizer)
return translator, ["Some string", "Some other text"]
def run_pipeline_test(self, translator, _):
outputs = translator("Some string")
self.assertEqual(outputs, [{"translation_text": ANY(str)}])
@require_torch

View File

@@ -31,9 +31,13 @@ class ZeroShotClassificationPipelineTests(unittest.TestCase, metaclass=PipelineT
model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
tf_model_mapping = TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING
def run_pipeline_test(self, model, tokenizer, feature_extractor):
classifier = ZeroShotClassificationPipeline(model=model, tokenizer=tokenizer)
def get_test_pipeline(self, model, tokenizer, feature_extractor):
classifier = ZeroShotClassificationPipeline(
model=model, tokenizer=tokenizer, candidate_labels=["polics", "health"]
)
return classifier, ["Who are you voting for in 2020?", "My stomach hurts."]
def run_pipeline_test(self, classifier, _):
outputs = classifier("Who are you voting for in 2020?", candidate_labels="politics")
self.assertEqual(outputs, {"sequence": ANY(str), "labels": [ANY(str)], "scores": [ANY(float)]})