Fix whisper for pipeline (#19482)

* update feature extractor params

* update attention mask handling

* fix doc and pipeline test

* add warning when skipping test

* add whisper translation and transcription test

* fix build doc test
This commit is contained in:
Arthur
2022-10-11 13:17:53 +02:00
committed by Sylvain Gugger
parent 9ae22fe3c1
commit c8bc0a0b02
4 changed files with 90 additions and 24 deletions

View File

@@ -26,6 +26,8 @@ from transformers import (
AutoTokenizer,
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output = asr(data)
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
@slow
@require_torch
@require_torchaudio
def test_simple_whisper_asr(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-tiny.en",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
filename = ds[0]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
@slow
@require_torch
@require_torchaudio
def test_simple_whisper_translation(self):
speech_recognizer = pipeline(
task="automatic-speech-recognition",
model="openai/whisper-large",
framework="pt",
)
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
filename = ds[40]["file"]
output = speech_recognizer(filename)
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
@slow
@require_torch
@require_torchaudio

View File

@@ -178,8 +178,16 @@ class ANY:
class PipelineTestCaseMeta(type):
def __new__(mcs, name, bases, dct):
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
@skipIf(tiny_config is None, "TinyConfig does not exist")
@skipIf(checkpoint is None, "checkpoint does not exist")
@skipIf(
tiny_config is None,
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
" file",
)
@skipIf(
checkpoint is None,
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
" modeling file",
)
def test(self):
if ModelClass.__name__.endswith("ForCausalLM"):
tiny_config.is_encoder_decoder = False