[Whisper] Refactor whisper (#21252)

* update whisper logit processor

* add generate for whisper

* remove part of the whisper specific code from pipeline

* update logit processors

* major update

* enforce first timestamp

* update generate

* add more tests

* update new decoding strategy

* Apply suggestions from code review

* update docstring

* fixup

* default config will not have multilingual ar

* update expected tokenizer size, see pull request on the hub for whisper-tiny
This commit is contained in:
Arthur
2023-01-25 13:09:43 +01:00
committed by GitHub
parent f83135eb76
commit 255257f3ea
6 changed files with 231 additions and 55 deletions

View File

@@ -28,7 +28,6 @@ from transformers import (
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
},
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
@@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 5.44),
}
],
},
)
@slow
@require_torch
@@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
# either use generate_kwargs or set the model's generation_config
# model.generation_config.task = "transcribe"
# model.generation_config.language = "<|it|>"
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})