From 3e07196f896946f084ce65db852869981c533a98 Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 18 Oct 2022 09:14:12 +0200 Subject: [PATCH] check decoder_inputs_embeds is None before shifting labels (#19671) --- src/transformers/models/bart/modeling_tf_bart.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 2 +- src/transformers/models/blenderbot/modeling_tf_blenderbot.py | 2 +- .../models/blenderbot_small/modeling_blenderbot_small.py | 2 +- .../models/blenderbot_small/modeling_tf_blenderbot_small.py | 2 +- src/transformers/models/led/modeling_led.py | 2 +- src/transformers/models/led/modeling_tf_led.py | 2 +- src/transformers/models/marian/modeling_marian.py | 2 +- src/transformers/models/marian/modeling_tf_marian.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 2 +- src/transformers/models/mbart/modeling_tf_mbart.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 2 +- src/transformers/models/pegasus/modeling_tf_pegasus.py | 2 +- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 2 +- src/transformers/models/plbart/modeling_plbart.py | 2 +- .../models/speech_to_text/modeling_speech_to_text.py | 2 +- .../models/speech_to_text/modeling_tf_speech_to_text.py | 2 +- src/transformers/models/whisper/modeling_tf_whisper.py | 2 +- src/transformers/models/whisper/modeling_whisper.py | 2 +- .../modeling_tf_{{cookiecutter.lowercase_modelname}}.py | 2 +- .../modeling_{{cookiecutter.lowercase_modelname}}.py | 2 +- 21 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py index ed33b9eb5e..5de35b5a7c 100644 --- a/src/transformers/models/bart/modeling_tf_bart.py +++ b/src/transformers/models/bart/modeling_tf_bart.py @@ -1352,7 +1352,7 @@ class TFBartForConditionalGeneration(TFBartPretrainedModel, TFCausalLanguageMode labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 303a5c4f25..aaa62d9feb 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1319,7 +1319,7 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py index f6ecc34c21..50ceab3e02 100644 --- a/src/transformers/models/blenderbot/modeling_tf_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_tf_blenderbot.py @@ -1371,7 +1371,7 @@ class TFBlenderbotForConditionalGeneration(TFBlenderbotPreTrainedModel, TFCausal labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 8dac9b6a75..1d8de259b7 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1286,7 +1286,7 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py index ba112256fc..e6c066af12 100644 --- a/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_tf_blenderbot_small.py @@ -1351,7 +1351,7 @@ class TFBlenderbotSmallForConditionalGeneration(TFBlenderbotSmallPreTrainedModel labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index ff79c0cad4..0c251df3cc 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2428,7 +2428,7 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/led/modeling_tf_led.py b/src/transformers/models/led/modeling_tf_led.py index a17b169a8b..76fd65eada 100644 --- a/src/transformers/models/led/modeling_tf_led.py +++ b/src/transformers/models/led/modeling_tf_led.py @@ -2445,7 +2445,7 @@ class TFLEDForConditionalGeneration(TFLEDPreTrainedModel): if labels is not None: use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 26dc6b12dc..fe5ad63e53 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1432,7 +1432,7 @@ class MarianMTModel(MarianPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py index 8c54854310..b040647fb6 100644 --- a/src/transformers/models/marian/modeling_tf_marian.py +++ b/src/transformers/models/marian/modeling_tf_marian.py @@ -1388,7 +1388,7 @@ class TFMarianMTModel(TFMarianPreTrainedModel, TFCausalLanguageModelingLoss): labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 66011fe6a7..1f532b9d69 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1347,7 +1347,7 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) outputs = self.model( diff --git a/src/transformers/models/mbart/modeling_tf_mbart.py b/src/transformers/models/mbart/modeling_tf_mbart.py index 36512aa823..16106ad30b 100644 --- a/src/transformers/models/mbart/modeling_tf_mbart.py +++ b/src/transformers/models/mbart/modeling_tf_mbart.py @@ -1387,7 +1387,7 @@ class TFMBartForConditionalGeneration(TFMBartPreTrainedModel, TFCausalLanguageMo labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) outputs = self.model( diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 5a144aa3e9..6bbdee60a7 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1393,7 +1393,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py index dbfbcf6d02..913b8f7703 100644 --- a/src/transformers/models/pegasus/modeling_tf_pegasus.py +++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py @@ -1397,7 +1397,7 @@ class TFPegasusForConditionalGeneration(TFPegasusPreTrainedModel, TFCausalLangua labels, ) use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 8feac43571..0b813ad1a8 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1605,7 +1605,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index d86decb568..fa58563ec4 100755 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1314,7 +1314,7 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) outputs = self.model( diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index dd118ee654..730e130b2f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1341,7 +1341,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py index ef37691448..364b515024 100755 --- a/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_tf_speech_to_text.py @@ -1405,7 +1405,7 @@ class TFSpeech2TextForConditionalGeneration(TFSpeech2TextPreTrainedModel, TFCaus return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/whisper/modeling_tf_whisper.py b/src/transformers/models/whisper/modeling_tf_whisper.py index a24ccdd4a7..3bf6712d2b 100644 --- a/src/transformers/models/whisper/modeling_tf_whisper.py +++ b/src/transformers/models/whisper/modeling_tf_whisper.py @@ -1293,7 +1293,7 @@ class TFWhisperForConditionalGeneration(TFWhisperPreTrainedModel, TFCausalLangua return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 08204c51d8..92079013fd 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1183,7 +1183,7 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict if labels is not None: - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py index 71239f580c..3e9802e205 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_tf_{{cookiecutter.lowercase_modelname}}.py @@ -2938,7 +2938,7 @@ class TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration(TF{{cookiec if labels is not None: use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right( labels, self.config.pad_token_id, self.config.decoder_start_token_id ) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 3bd3b1894a..9e2154901a 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -2833,7 +2833,7 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte if use_cache: logger.warning("The `use_cache` argument is changed to `False` since `labels` is provided.") use_cache = False - if decoder_input_ids is None: + if decoder_input_ids is None and decoder_inputs_embeds is None: decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) outputs = self.model(