From da7ea9a4e337eb2eed204090fe38198418c01134 Mon Sep 17 00:00:00 2001 From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com> Date: Tue, 7 Nov 2023 10:04:23 +0000 Subject: [PATCH] [Whisper] Block language/task args for English-only (#27322) * [Whisper] Block language/task args for English-only * Update src/transformers/models/whisper/modeling_whisper.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- .../models/whisper/modeling_whisper.py | 16 ++++++++ ..._pipelines_automatic_speech_recognition.py | 38 +++++++++++++++++++ 2 files changed, 54 insertions(+) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index a107adf74e..ad54d51b73 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1841,6 +1841,22 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): else: generation_config.return_timestamps = False + if is_multilingual is not None: + if not hasattr(generation_config, "is_multilingual"): + raise ValueError( + "The generation config is outdated and is thus not compatible with the `is_multilingual` argument " + "to `generate`. Please update the generation config as per the instructions " + "https://github.com/huggingface/transformers/issues/25084#issuecomment-1664398224" + ) + generation_config.is_multilingual = is_multilingual + + if hasattr(generation_config, "is_multilingual") and not generation_config.is_multilingual: + if task is not None or language is not None: + raise ValueError( + "Cannot specify `task` or `language` for an English-only model. If the model is intended to be " + "multilingual, pass `is_multilingual=True` to generate, or update the generation config." + ) + if language is not None: if not hasattr(generation_config, "lang_to_id"): raise ValueError( diff --git a/tests/pipelines/test_pipelines_automatic_speech_recognition.py b/tests/pipelines/test_pipelines_automatic_speech_recognition.py index 0343c32939..ea62198e2e 100644 --- a/tests/pipelines/test_pipelines_automatic_speech_recognition.py +++ b/tests/pipelines/test_pipelines_automatic_speech_recognition.py @@ -852,6 +852,44 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase): output_3 = speech_translator(filename) self.assertEqual(output_3, {"text": " Un uomo ha detto all'universo, Sir, esiste."}) + @slow + @require_torch + def test_whisper_language(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny.en", + framework="pt", + ) + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + filename = ds[0]["file"] + + # 1. English-only model compatible with no language argument + output = speech_recognizer(filename) + self.assertEqual( + output, + {"text": " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."}, + ) + + # 2. English-only Whisper does not accept the language argument + with self.assertRaisesRegex( + ValueError, + "Cannot specify `task` or `langauge` for an English-only model. If the model is intended to be multilingual, " + "pass `is_multilingual=True` to generate, or update the generation config.", + ): + _ = speech_recognizer(filename, generate_kwargs={"language": "en"}) + + # 3. Multilingual model accepts language argument + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="openai/whisper-tiny", + framework="pt", + ) + output = speech_recognizer(filename, generate_kwargs={"language": "en"}) + self.assertEqual( + output, + {"text": " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."}, + ) + @slow @require_torch @require_torchaudio