Fix whisper for pipeline (#19482)
* update feature extractor params * update attention mask handling * fix doc and pipeline test * add warning when skipping test * add whisper translation and transcription test * fix build doc test
This commit is contained in:
@@ -26,6 +26,8 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
Speech2TextForConditionalGeneration,
|
||||
Wav2Vec2ForCTC,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||
@@ -308,6 +310,52 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output = asr(data)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_simple_whisper_asr(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-tiny.en",
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
|
||||
filename = ds[0]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to"})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
def test_simple_whisper_translation(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="openai/whisper-large",
|
||||
framework="pt",
|
||||
)
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": " A man said to the universe, Sir, I exist."})
|
||||
|
||||
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-large")
|
||||
tokenizer = AutoTokenizer.from_pretrained("openai/whisper-large")
|
||||
feature_extractor = AutoFeatureExtractor.from_pretrained("openai/whisper-large")
|
||||
|
||||
speech_recognizer_2 = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
)
|
||||
output_2 = speech_recognizer_2(filename)
|
||||
self.assertEqual(output, output_2)
|
||||
|
||||
processor = WhisperProcessor(feature_extractor, tokenizer)
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
|
||||
speech_translator = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
)
|
||||
output_3 = speech_translator(filename)
|
||||
self.assertEqual(output_3, {"text": " Un uomo ha detto allo universo, Sir, esiste."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
|
||||
@@ -178,8 +178,16 @@ class ANY:
|
||||
class PipelineTestCaseMeta(type):
|
||||
def __new__(mcs, name, bases, dct):
|
||||
def gen_test(ModelClass, checkpoint, tiny_config, tokenizer_class, feature_extractor_class):
|
||||
@skipIf(tiny_config is None, "TinyConfig does not exist")
|
||||
@skipIf(checkpoint is None, "checkpoint does not exist")
|
||||
@skipIf(
|
||||
tiny_config is None,
|
||||
"TinyConfig does not exist, make sure that you defined a `_CONFIG_FOR_DOC` variable in the modeling"
|
||||
" file",
|
||||
)
|
||||
@skipIf(
|
||||
checkpoint is None,
|
||||
"checkpoint does not exist, make sure that you defined a `_CHECKPOINT_FOR_DOC` variable in the"
|
||||
" modeling file",
|
||||
)
|
||||
def test(self):
|
||||
if ModelClass.__name__.endswith("ForCausalLM"):
|
||||
tiny_config.is_encoder_decoder = False
|
||||
|
||||
Reference in New Issue
Block a user