[Whisper] Refactor whisper (#21252)

* update whisper logit processor

* add generate for whisper

* remove part of the whisper specific code from pipeline

* update logit processes

* major update

* enforce first timestamp

* update generate

* add more tests

* update new decoding strategy

* Apply suggestions from code review

* update docstring

* fixup

* default config will not have multilingual ar

* update expected tokenizer size, see pull on the hub for whisper-tiny
This commit is contained in:
Arthur
2023-01-25 13:09:43 +01:00
committed by GitHub
parent f83135eb76
commit 255257f3ea
6 changed files with 231 additions and 55 deletions

View File

@@ -59,7 +59,7 @@ class WhisperTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
self.assertEqual(len(vocab_keys), 50364)
def test_vocab_size(self):
self.assertEqual(self.get_tokenizer().vocab_size, 50257)
self.assertEqual(self.get_tokenizer().vocab_size, 50258)
def test_full_tokenizer(self):
tokenizer = WhisperTokenizer.from_pretrained(self.tmpdirname)
@@ -265,7 +265,15 @@ class SpeechToTextTokenizerMultilinguialTest(unittest.TestCase):
},
],
)
# test `decode_with_offsets`
output = multilingual_tokenizer.decode(INPUT_TOKENS, decode_with_timestamps=True)
self.assertEqual(
output,
"<|startoftranscript|><|en|><|transcribe|><|0.00|> Lennils, pictures are a sort of upguards and atom"
" paintings, and Mason's exquisite idles<|7.20|><|7.20|> are as national as a jingo poem. Mr. Birkut"
" Foster's landscapes smile at one much in the<|15.16|><|15.16|> same way that Mr. Carker used to flash"
" his teeth. And Mr. John Colier gives his<|21.70|><|21.70|><|endoftext|>",
)
# test a single sequence with timestamps
# fmt: off
INPUT_TOKENS = [

View File

@@ -28,7 +28,6 @@ from transformers import (
Speech2TextForConditionalGeneration,
Wav2Vec2ForCTC,
WhisperForConditionalGeneration,
WhisperProcessor,
)
from transformers.pipelines import AutomaticSpeechRecognitionPipeline, pipeline
from transformers.pipelines.audio_utils import chunk_bytes_iter
@@ -523,10 +522,6 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
"chunks": [{"text": " A man said to the universe, Sir, I exist.", "timestamp": (0.0, 4.26)}],
},
)
pipe = pipeline(
model="openai/whisper-small",
return_timestamps=True,
)
output = pipe(array, chunk_length_s=10)
self.assertDictEqual(
@@ -687,6 +682,21 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output,
{"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."},
)
output = speech_recognizer(filename, return_timestamps=True)
self.assertEqual(
output,
{
"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.",
"chunks": [
{
"text": (
" Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
),
"timestamp": (0.0, 5.44),
}
],
},
)
@slow
@require_torch
@@ -712,10 +722,14 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
output_2 = speech_recognizer_2(filename)
self.assertEqual(output, output_2)
processor = WhisperProcessor(feature_extractor, tokenizer)
model.config.forced_decoder_ids = processor.get_decoder_prompt_ids(task="transcribe", language="it")
# either use generate_kwargs or set the model's generation_config
# model.generation_config.task = "transcribe"
# model.generation_config.lang = "<|it|>"
speech_translator = AutomaticSpeechRecognitionPipeline(
model=model, tokenizer=tokenizer, feature_extractor=feature_extractor
model=model,
tokenizer=tokenizer,
feature_extractor=feature_extractor,
generate_kwargs={"task": "transcribe", "language": "<|it|>"},
)
output_3 = speech_translator(filename)
self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."})