From 57edd84bdb1f8a8c2bd6229d5be6a0a21991135b Mon Sep 17 00:00:00 2001
From: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
Date: Fri, 17 May 2024 15:12:44 +0100
Subject: [PATCH] [whisper] fix multilingual fine-tuning (#30865)

* [whisper] fix multilingual fine-tuning

* config ids as well
---
 .../run_speech_recognition_seq2seq.py         | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
index 59d097fc72..700e589057 100755
--- a/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
+++ b/examples/pytorch/speech-recognition/run_speech_recognition_seq2seq.py
@@ -425,12 +425,8 @@ def main():
     if hasattr(model.generation_config, "is_multilingual") and model.generation_config.is_multilingual:
         # We only need to set the language and task ids in a multilingual setting
         tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
-        model.generation_config.update(
-            **{
-                "language": data_args.language,
-                "task": data_args.task,
-            }
-        )
+        model.generation_config.language = data_args.language
+        model.generation_config.task = data_args.task
     elif data_args.language is not None:
         raise ValueError(
             "Setting language token for an English-only checkpoint is not permitted. The language argument should "
@@ -444,6 +440,9 @@ def main():
             "Please use the `language` and `task` arguments instead"
         )
         model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids
+    else:
+        model.generation_config.forced_decoder_ids = None
+        model.config.forced_decoder_ids = None
 
     if model_args.suppress_tokens is not None:
         logger.warning(