From 17083b9b847c71e8c303e9cb0798a8928e99a6e0 Mon Sep 17 00:00:00 2001 From: Connor Henderson Date: Fri, 5 May 2023 11:52:19 -0400 Subject: [PATCH] fix: Passing language as acronym to Whisper generate (#23141) * add fix * address comments * remove error formatting --- .../models/whisper/modeling_whisper.py | 8 ++++++-- tests/models/whisper/test_modeling_whisper.py | 15 +++++++++++++++ 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index bde8009116..91de6810b1 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1562,6 +1562,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): generation_config.return_timestamps = False if language is not None: + language = language.lower() generation_config.language = language if task is not None: generation_config.task = task @@ -1573,10 +1574,13 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): language_token = generation_config.language elif generation_config.language in TO_LANGUAGE_CODE.keys(): language_token = f"<|{TO_LANGUAGE_CODE[generation_config.language]}|>" + elif generation_config.language in TO_LANGUAGE_CODE.values(): + language_token = f"<|{generation_config.language}|>" else: + is_language_code = len(generation_config.language) == 2 raise ValueError( - f"Unsupported language: {self.language}. Language should be one of:" - f" {list(TO_LANGUAGE_CODE.keys()) if generation_config.language in TO_LANGUAGE_CODE.keys() else list(TO_LANGUAGE_CODE.values())}." + f"Unsupported language: {generation_config.language}. Language should be one of:" + f" {list(TO_LANGUAGE_CODE.values()) if is_language_code else list(TO_LANGUAGE_CODE.keys())}." ) forced_decoder_ids.append((1, generation_config.lang_to_id[language_token])) else: diff --git a/tests/models/whisper/test_modeling_whisper.py b/tests/models/whisper/test_modeling_whisper.py index dd6ad07eb4..0591c6f464 100644 --- a/tests/models/whisper/test_modeling_whisper.py +++ b/tests/models/whisper/test_modeling_whisper.py @@ -414,6 +414,21 @@ class WhisperModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMi model.generate(input_features) model.generate(input_features, num_beams=4, do_sample=True, early_stopping=False, num_return_sequences=3) + def test_generate_language(self): + config, input_dict = self.model_tester.prepare_config_and_inputs() + input_features = input_dict["input_features"] + model = WhisperForConditionalGeneration(config).to(torch_device) + # Hack to keep the test fast and not require downloading a model with a generation_config + model.generation_config.__setattr__("lang_to_id", {"<|en|>": 1}) + model.generation_config.__setattr__("task_to_id", {"transcribe": 2}) + + # test language code + model.generate(input_features, language="en") + # test tokenizer code + model.generate(input_features, language="<|en|>") + # test language name + model.generate(input_features, language="English") + def test_forward_signature(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common()