From a515d0a77c769954ac2f0151a2a99c04d8d6cf95 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Tue, 4 Apr 2023 16:21:57 +0200 Subject: [PATCH] Soft error whisper. (#22475) * Soft error whisper. * Fix format. --------- Co-authored-by: Ubuntu --- .../models/whisper/tokenization_whisper.py | 4 +-- ..._pipelines_automatic_speech_recognition.py | 33 ++++++++++++++++++- 2 files changed, 33 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/whisper/tokenization_whisper.py b/src/transformers/models/whisper/tokenization_whisper.py index 0160237304..24eb72a0b0 100644 --- a/src/transformers/models/whisper/tokenization_whisper.py +++ b/src/transformers/models/whisper/tokenization_whisper.py @@ -877,9 +877,7 @@ def _decode_asr(tokenizer, model_outputs, *, return_timestamps, return_language, if previous_tokens: if return_timestamps: - # Last token should always be timestamps, so there shouldn't be - # leftover - raise ValueError( + logger.warning( "There was an error while processing timestamps, we haven't found a timestamp as last token. Was" " WhisperTimeStampLogitsProcessor used?" ) diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 5db3e3e46c..952508dca4 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -17,7 +17,7 @@ import unittest import numpy as np import pytest from datasets import load_dataset -from huggingface_hub import snapshot_download +from huggingface_hub import hf_hub_download, snapshot_download from transformers import ( MODEL_FOR_CTC_MAPPING, @@ -39,6 +39,7 @@ from transformers.testing_utils import ( require_pyctcdecode, require_tf, require_torch, + require_torch_gpu, require_torchaudio, slow, ) @@ -1158,6 +1159,36 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000}) self.assertEqual(output, {"text": "XB"}) + @slow + @require_torch_gpu + def test_slow_unfinished_sequence(self): + from transformers import GenerationConfig + + pipe = pipeline( + "automatic-speech-recognition", + model="vasista22/whisper-hindi-large-v2", + device="cuda:0", + ) + # Original model wasn't trained with timestamps and has incorrect generation config + pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2") + + audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset") + + out = pipe( + audio, + return_timestamps=True, + ) + self.assertEqual( + out, + { + "chunks": [ + {"text": "", "timestamp": (18.94, 0.0)}, + {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)}, + ], + "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", + }, + ) + def require_ffmpeg(test_case): """