[Whisper] Refactor whisper (#21252)
* update whisper logit processor * add generate for whisper * remove part of the whisper specific code from pipeline * update logit processes * major update * enforce first timestamp * update generate * add more tests * update new decoding strategy * Apply suggestions from code review * update docstring * fixup * default config will not have multilingual ar * update expected tokenizer size, see pull on the hub for whisper-tiny
This commit is contained in:
@@ -28,7 +28,6 @@ from transformers import (
|
||||
Speech2TextForConditionalGeneration,
|
||||
Wav2Vec2ForCTC,
|
||||
WhisperForConditionalGeneration,
|
||||
WhisperProcessor,
|
||||
)
|
||||
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
|
||||
from transformers.pipelines.audio_utils import chunk_bytes_iter
|
||||
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
|
||||
},
|
||||
)
|
||||
pipe = pipeline(
|
||||
model="openai/whisper-small",
|
||||
return_timestamps=True,
|
||||
)
|
||||
|
||||
output = pipe(array, chunk_length_s=10)
|
||||
self.assertDictEqual(
|
||||
@@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output,
|
||||
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
|
||||
)
|
||||
output = speech_recognizer(filename, return_timestamps=True)
|
||||
self.assertEqual(
|
||||
output,
|
||||
{
|
||||
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
|
||||
"chunks": [
|
||||
{
|
||||
"text": (
|
||||
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
|
||||
),
|
||||
"timestamp": (0.0, 5.44),
|
||||
}
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
output_2 = speech_recognizer_2(filename)
|
||||
self.assertEqual(output, output_2)
|
||||
|
||||
processor = WhisperProcessor(feature_extractor, tokenizer)
|
||||
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
|
||||
# either use generate_kwargs or set the model's generation_config
|
||||
# model.generation_config.task = "transcribe"
|
||||
# model.generation_config.lang = "<|it|>"
|
||||
speech_translator = AutomaticSpeechRecognitionPipeline(
|
||||
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
|
||||
model=model,
|
||||
tokenizer=tokenizer,
|
||||
feature_extractor=feature_extractor,
|
||||
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
|
||||
)
|
||||
output_3 = speech_translator(filename)
|
||||
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})
|
||||
|
||||
Reference in New Issue
Block a user