[Whisper] Refactor forced_decoder_ids & prompt ids (#28687)

* up

* Fix more

* Correct more

* Fix more tests

* fix fast tests

* Fix more

* fix more

* push all files

* finish all

* make style

* Fix timestamp wrap

* make style

* make style

* up

* up

* up

* Fix lang detection behavior

* Fix lang detection behavior

* Add lang detection test

* Fix lang detection behavior

* make style

* Update src/transformers/models/whisper/generation_whisper.py

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* better error message

* make style tests

* add warning

---------

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2024-01-31 14:02:07 +02:00
committed by GitHub
parent f9f1f2ac5e
commit 65a926e82b
3 changed files with 605 additions and 228 deletions

View File

@@ -1451,6 +1451,7 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
# Original model wasn't trained with timestamps and has incorrect generation config
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
# the audio is 4 seconds long
audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
out = pipe(
@@ -1460,11 +1461,8 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase):
self.assertEqual(
out,
{
"chunks": [
{"text": "", "timestamp": (18.94, 0.02)},
{"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
],
"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
"chunks": [{"timestamp": (0.58, None), "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं"}],
},
)