From dcc49d8a7ef91c5e1baeb4d510ec4f37bc259760 Mon Sep 17 00:00:00 2001 From: Billy Bradley Date: Wed, 11 Oct 2023 12:18:42 +0100 Subject: [PATCH] In assisted decoding, pass model_kwargs to model's forward call (fix prepare_input_for_generation in all models) (#25242) * In assisted decoding, pass model_kwargs to model's forward call Previously, assisted decoding would ignore any additional kwargs that it doesn't explicitly handle. This was inconsistent with other generation methods, which pass the model_kwargs through prepare_inputs_for_generation and forward the returned dict to the model's forward call. The prepare_inputs_for_generation method needs to be amended in all models, as previously it only kept the last input ID when a past_key_values was passed. * Improve variable names in _extend_attention_mask * Refactor extending token_type_ids into a function * Replace deepcopy with copy to optimize performance * Update new persimmon model with llama changes for assisted generation * Update new mistral model for assisted generation with prepare_inputs_for_generation * Update position_ids creation in falcon prepare_inputs_for_generation to support assisted generation --- src/transformers/generation/utils.py | 91 +++++++++++-------- src/transformers/models/bark/modeling_bark.py | 15 ++- src/transformers/models/bart/modeling_bart.py | 22 ++++- src/transformers/models/bert/modeling_bert.py | 11 ++- .../modeling_bert_generation.py | 13 ++- .../models/big_bird/modeling_big_bird.py | 13 ++- .../modeling_bigbird_pegasus.py | 11 ++- .../models/biogpt/modeling_biogpt.py | 15 ++- .../models/blenderbot/modeling_blenderbot.py | 22 ++++- .../modeling_blenderbot_small.py | 22 ++++- .../models/blip/modeling_blip_text.py | 11 ++- .../models/bloom/modeling_bloom.py | 15 ++- .../models/camembert/modeling_camembert.py | 13 ++- .../models/codegen/modeling_codegen.py | 17 +++- src/transformers/models/ctrl/modeling_ctrl.py | 15 ++- .../models/data2vec/modeling_data2vec_text.py | 13 ++- .../open_llama/modeling_open_llama.py | 15 ++- .../models/electra/modeling_electra.py | 13 ++- .../models/ernie/modeling_ernie.py | 11 ++- .../models/falcon/modeling_falcon.py | 13 ++- src/transformers/models/gpt2/modeling_gpt2.py | 35 +++++-- .../gpt_bigcode/modeling_gpt_bigcode.py | 20 +++- .../models/gpt_neo/modeling_gpt_neo.py | 17 +++- .../models/gpt_neox/modeling_gpt_neox.py | 19 +++- src/transformers/models/gptj/modeling_gptj.py | 17 +++- .../models/imagegpt/modeling_imagegpt.py | 17 +++- .../models/llama/modeling_llama.py | 15 ++- .../models/longt5/modeling_longt5.py | 13 ++- .../models/m2m_100/modeling_m2m_100.py | 11 ++- .../models/marian/modeling_marian.py | 22 ++++- .../models/markuplm/modeling_markuplm.py | 11 ++- .../models/mbart/modeling_mbart.py | 22 ++++- .../megatron_bert/modeling_megatron_bert.py | 13 ++- .../models/mistral/modeling_mistral.py | 14 ++- src/transformers/models/mpt/modeling_mpt.py | 15 ++- src/transformers/models/mt5/modeling_mt5.py | 13 ++- .../models/musicgen/modeling_musicgen.py | 12 ++- src/transformers/models/mvp/modeling_mvp.py | 22 ++++- .../models/nllb_moe/modeling_nllb_moe.py | 11 ++- src/transformers/models/opt/modeling_opt.py | 13 ++- .../models/pegasus/modeling_pegasus.py | 22 ++++- .../models/pegasus_x/modeling_pegasus_x.py | 11 ++- .../models/persimmon/modeling_persimmon.py | 15 ++- .../models/pix2struct/modeling_pix2struct.py | 13 ++- .../models/plbart/modeling_plbart.py | 22 ++++- .../models/qdqbert/modeling_qdqbert.py | 13 ++- .../models/rembert/modeling_rembert.py | 13 ++- .../models/roberta/modeling_roberta.py | 13 ++- .../modeling_roberta_prelayernorm.py | 13 ++- .../models/roc_bert/modeling_roc_bert.py | 13 ++- .../models/roformer/modeling_roformer.py | 13 ++- .../modeling_speech_to_text_2.py | 11 ++- .../models/speecht5/modeling_speecht5.py | 11 ++- .../modeling_switch_transformers.py | 13 ++- src/transformers/models/t5/modeling_t5.py | 13 ++- .../models/trocr/modeling_trocr.py | 11 ++- src/transformers/models/umt5/modeling_umt5.py | 13 ++- .../models/whisper/modeling_whisper.py | 12 ++- src/transformers/models/xglm/modeling_xglm.py | 17 +++- .../xlm_roberta/modeling_xlm_roberta.py | 13 ++- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 13 ++- src/transformers/models/xmod/modeling_xmod.py | 13 ++- tests/generation/test_utils.py | 86 ++++++++++++++++++ 63 files changed, 911 insertions(+), 179 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 3b1bef6f04..49b213cc5e 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1297,6 +1297,43 @@ class GenerationMixin: UserWarning, ) + def _extend_attention_mask(self, model_kwargs: Dict[str, Any], new_mask_length: int) -> Dict[str, Any]: + if self.config.is_encoder_decoder: + key = "decoder_attention_mask" + else: + key = "attention_mask" + + if key not in model_kwargs: + return model_kwargs + + mask = model_kwargs[key] + mask_extension_length = new_mask_length - mask.shape[1] + + if mask_extension_length < 0: + raise ValueError("Cannot extend attention mask to a length less than it already is") + + model_kwargs[key] = torch.cat( + [mask, mask.new_ones((mask.shape[0], mask_extension_length))], + dim=-1, + ) + + return model_kwargs + + def _extend_token_type_ids(self, model_kwargs: Dict[str, Any], new_length: int) -> Dict[str, Any]: + if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None: + return model_kwargs + + token_type_ids = model_kwargs["token_type_ids"] + final_token_type = token_type_ids[:, -1].unsqueeze(-1) + extension_length = new_length - token_type_ids.shape[1] + token_type_copies = final_token_type.repeat(1, extension_length) + model_kwargs["token_type_ids"] = torch.cat( + [model_kwargs["token_type_ids"], token_type_copies], + dim=-1, + ) + + return model_kwargs + @torch.no_grad() def generate( self, @@ -4441,47 +4478,21 @@ class GenerationMixin: # `candidate_length + 1` relevant logits from this process: in the event that all candidates are correct, # we use this forward pass to also pick the subsequent logits in the original model. - # 2.1. Run a forward pass on the candidate sequence - if "past_key_values" in model_kwargs: - model_attn = torch.ones_like(candidate_input_ids) - model_input_ids = candidate_input_ids[:, -candidate_length - 1 :] - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=model_input_ids, - decoder_attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - model_input_ids, - attention_mask=model_attn, - past_key_values=model_kwargs["past_key_values"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - if self.config.is_encoder_decoder: - outputs = self( - decoder_input_ids=candidate_input_ids, - encoder_outputs=model_kwargs["encoder_outputs"], - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) - else: - outputs = self( - candidate_input_ids, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - use_cache=True, - ) + # 2.1. Prepare the model inputs + candidate_kwargs = copy.copy(model_kwargs) + candidate_kwargs = self._extend_attention_mask(candidate_kwargs, candidate_input_ids.shape[1]) + candidate_kwargs = self._extend_token_type_ids(candidate_kwargs, candidate_input_ids.shape[1]) - # 2.2. Process the new logits + model_inputs = self.prepare_inputs_for_generation(candidate_input_ids, **candidate_kwargs) + + # 2.2. Run a forward pass on the candidate sequence + outputs = self( + **model_inputs, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + # 2.3. Process the new logits new_logits = outputs.logits[:, -candidate_length - 1 :] # excludes the input prompt if present if len(logits_processor) > 0: for i in range(candidate_length): diff --git a/src/transformers/models/bark/modeling_bark.py b/src/transformers/models/bark/modeling_bark.py index bdafb63477..649719e0ee 100644 --- a/src/transformers/models/bark/modeling_bark.py +++ b/src/transformers/models/bark/modeling_bark.py @@ -483,9 +483,18 @@ class BarkCausalModel(BarkPreTrainedModel): position_ids = kwargs.get("position_ids", None) if past_key_values is not None: - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values seq_len = input_ids.shape[1] - input_ids = input_ids[:, [-1]] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # input_embeds have already been used and is not required anymore input_embeds = None @@ -507,7 +516,7 @@ class BarkCausalModel(BarkPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 52dfa5e392..9e7763ca23 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1443,7 +1443,16 @@ class BartForConditionalGeneration(BartPreTrainedModel): ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1934,7 +1943,16 @@ class BartForCausalLM(BartPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 29846b8051..1b0fad3f9d 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1282,7 +1282,16 @@ class BertLMHeadModel(BertPreTrainedModel): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index f245ac155e..abe2d828b2 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -993,9 +993,18 @@ class BertGenerationDecoder(BertGenerationPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 867aca67e9..e266b1a67b 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -2628,9 +2628,18 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index a32f3ecde7..4e279f9dc0 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2627,7 +2627,16 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): ): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7534ed17fe..d1c471aa80 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -729,9 +729,18 @@ class BioGptForCausalLM(BioGptPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, attention_mask, inputs_embeds=None, past_key_values=None, **kwargs ): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index bdb8c52a55..1db8190521 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1392,7 +1392,16 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1622,7 +1631,16 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index a1e888aec9..129de3dd14 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1359,7 +1359,16 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1589,7 +1598,16 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index 2ae3ac053b..49b958afc2 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -920,7 +920,16 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/bloom/modeling_bloom.py b/src/transformers/models/bloom/modeling_bloom.py index d12ec1724f..d90bb6ad8f 100644 --- a/src/transformers/models/bloom/modeling_bloom.py +++ b/src/transformers/models/bloom/modeling_bloom.py @@ -844,9 +844,18 @@ class BloomForCausalLM(BloomPreTrainedModel): inputs_embeds: Optional[torch.Tensor] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # the cache may be in the stardard format (e.g. in contrastive search), convert to bloom's format if needed if past_key_values[0][0].shape[0] == input_ids.shape[0]: diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 4635c06198..8d7d279579 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -1542,9 +1542,18 @@ class CamembertForCausalLM(CamembertPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 93d5aa7ee4..172a45544b 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -617,11 +617,20 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -631,7 +640,7 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 70cd4ec059..cec68de07d 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -526,9 +526,18 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): self.lm_head = new_embeddings def prepare_inputs_for_generation(self, input_ids, past_key_values=None, use_cache=None, **kwargs): - # only last token for inputs_ids if past is defined in kwargs - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for inputs_ids if past is defined in kwargs + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "past_key_values": past_key_values, "use_cache": use_cache} diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 7cbaee6925..a521ccb39a 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -1009,9 +1009,18 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index c975aa4087..6853f5333f 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -843,8 +843,17 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -852,7 +861,7 @@ class OpenLlamaForCausalLM(OpenLlamaPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index c06d306c1a..da3ee8e51d 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1667,9 +1667,18 @@ class ElectraForCausalLM(ElectraPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 7ee6f43812..d55155f800 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1223,7 +1223,16 @@ class ErnieForCausalLM(ErniePreTrainedModel): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index ab29322613..33b9fdde73 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -1228,7 +1228,16 @@ class FalconForCausalLM(FalconPreTrainedModel): **kwargs, ) -> dict: if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # Note: versions of Falcon with alibi do not use position_ids. It is used with RoPE. if not self.transformer.use_alibi and attention_mask is not None and position_ids is None: @@ -1236,7 +1245,7 @@ class FalconForCausalLM(FalconPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] return { "input_ids": input_ids, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 714f0351b3..838e7ca299 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1005,11 +1005,20 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1019,7 +1028,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None @@ -1038,6 +1047,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): "token_type_ids": token_type_ids, } ) + return model_inputs @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING) @@ -1201,11 +1211,20 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -1215,7 +1234,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index d58e00af1d..be90f61e45 100644 --- a/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/src/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -737,11 +737,23 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + if self.config.multi_query: + past_length = past_key_values[0].shape[1] + else: + past_length = past_key_values[0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -751,7 +763,7 @@ class GPTBigCodeForCausalLM(GPTBigCodePreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 6364cfc316..3ad49554c0 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -680,11 +680,20 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -694,7 +703,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index b4aa415445..9391805a77 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -808,10 +808,21 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): input_shape = input_ids.shape + print(input_shape) + print(past_key_values[0][0].shape if past_key_values is not None else "no pkv") # cut decoder_input_ids if past is used - if past_key_values and past_key_values[0] is not None: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -819,7 +830,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: @@ -830,7 +841,7 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): model_inputs = {"inputs_embeds": inputs_embeds} else: model_inputs = {"input_ids": input_ids} - + print(position_ids.shape) model_inputs.update( { "attention_mask": attention_mask, diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index a93bdeaacd..6b5607f235 100644 --- a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -785,11 +785,20 @@ class GPTJForCausalLM(GPTJPreTrainedModel): def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -799,7 +808,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 5f193a137b..54edcd30fc 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -912,11 +912,20 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): def prepare_inputs_for_generation(self, input_ids: torch.Tensor, past_key_values: Optional[bool] = None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) - # only last token for inputs_ids if past is defined in kwargs + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if token_type_ids is not None: - token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + token_type_ids = token_type_ids[:, -input_ids.shape[1] :] attention_mask = kwargs.get("attention_mask", None) position_ids = kwargs.get("position_ids", None) @@ -926,7 +935,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None return { diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 55753d5f75..4afa3293ed 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -1080,8 +1080,17 @@ class LlamaForCausalLM(LlamaPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1089,7 +1098,7 @@ class LlamaForCausalLM(LlamaPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index d08ed83af0..4e8aef0678 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2103,9 +2103,18 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel): encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 88e543b54b..6db8bbb521 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1367,7 +1367,16 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index b4e3aac5be..69de5b2e7d 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1509,7 +1509,16 @@ class MarianMTModel(MarianPreTrainedModel): ) -> Dict: # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1740,7 +1749,16 @@ class MarianForCausalLM(MarianPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index ca6bea4033..530c66a0c8 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -948,7 +948,16 @@ class MarkupLMModel(MarkupLMPreTrainedModel): # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "input_ids": input_ids, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 276f94aebd..b53ad8848d 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1413,7 +1413,16 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1897,7 +1906,16 @@ class MBartForCausalLM(MBartPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 1c1eeff667..5d0ad6e341 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -1251,9 +1251,18 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index a55f16a23d..62610ceb41 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1083,8 +1083,18 @@ class MistralForCausalLM(MistralPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): + # Omit tokens covered by past_key_values if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -1092,7 +1102,7 @@ class MistralForCausalLM(MistralPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 0c608dbd2a..d760bec985 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -605,9 +605,18 @@ class MptForCausalLM(MptPreTrainedModel): use_cache: Optional[bool] = None, **kwargs, ) -> dict: - # only last token for input_ids if past is not None - if past_key_values: - input_ids = input_ids[:, -1].unsqueeze(-1) + # only last tokens for input_ids if past is not None + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 3d03503ddd..0de50afe9d 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1836,9 +1836,18 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f178a67620..16766e953c 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -1995,9 +1995,17 @@ class MusicgenForConditionalGeneration(PreTrainedModel): if decoder_attention_mask is not None: decoder_attention_mask = decoder_attention_mask.repeat((2, 1)) - # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 21a82f95c3..5c1ed05249 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -1572,7 +1572,16 @@ class MvpForConditionalGeneration(MvpPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -2054,7 +2063,16 @@ class MvpForCausalLM(MvpPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index f37f64627d..3701bbecef 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1808,7 +1808,16 @@ class NllbMoeForConditionalGeneration(NllbMoePreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index d24211f039..8f3f246524 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -981,8 +981,17 @@ class OPTForCausalLM(OPTPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 67934520fb..55856f7b06 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1466,7 +1466,16 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1719,7 +1728,16 @@ class PegasusForCausalLM(PegasusPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index def82bdbaa..e87e9c7164 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1671,7 +1671,16 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c09657c065..a0bc572638 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -847,8 +847,17 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs ): - if past_key_values: - input_ids = input_ids[:, -1:] + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: @@ -856,7 +865,7 @@ class PersimmonForCausalLM(PersimmonPreTrainedModel): position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 288e31a126..e19761803e 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1798,9 +1798,18 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): if decoder_attention_mask is None: decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "flattened_patches": flattened_patches, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 93532f4b0d..3a88083923 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1379,7 +1379,16 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel): ) -> Dict[str, Any]: # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "input_ids": None, # encoder_outputs is defined. input_ids not needed @@ -1739,7 +1748,16 @@ class PLBartForCausalLM(PLBartPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index 47546930eb..fead8fc0cf 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -1151,9 +1151,18 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 745be26ebf..235bff89f8 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -1147,9 +1147,18 @@ class RemBertForCausalLM(RemBertPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 67e0fee422..6d4cc991d2 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1007,9 +1007,18 @@ class RobertaForCausalLM(RobertaPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index ddd87fa9ce..da1cd6331b 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -1014,9 +1014,18 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index 35d4be9f20..a5b1b63050 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -1560,9 +1560,18 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] if input_shape_ids is not None: input_shape_ids = input_shape_ids[:, -1:] if input_pronunciation_ids is not None: diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 2c3feeda12..b9c36a305f 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -1178,9 +1178,18 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index bfd801b242..f9b5dec420 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -963,7 +963,16 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index 48334deb37..c4de7de090 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2508,7 +2508,16 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): ): # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "encoder_outputs": encoder_outputs, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 6c2fe82697..541db4382d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1727,9 +1727,18 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index e6d9deefa1..9716c7ffaf 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1810,9 +1810,18 @@ class T5ForConditionalGeneration(T5PreTrainedModel): encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 50829592a0..c0541814be 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -1003,7 +1003,16 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel): attention_mask = input_ids.new_ones(input_ids.shape) if past_key_values: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index 8323054144..bd35111be1 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -1307,9 +1307,18 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel): encoder_outputs=None, **kwargs, ): - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return { "decoder_input_ids": input_ids, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index be5f50dbff..de1565fa76 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1810,9 +1810,17 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): attention_mask=None, **kwargs, ): - # cut decoder_input_ids if past is used if past_key_values is not None: - decoder_input_ids = decoder_input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if decoder_input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = decoder_input_ids.shape[1] - 1 + + decoder_input_ids = decoder_input_ids[:, remove_prefix_length:] return { "encoder_outputs": encoder_outputs, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 5f8778f98d..0c769dbbb5 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -851,21 +851,30 @@ class XGLMForCausalLM(XGLMPreTrainedModel): def prepare_inputs_for_generation( self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs ): + if past_key_values is not None: + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] + position_ids = kwargs.get("position_ids", None) if attention_mask is not None and position_ids is None: # create position_ids on the fly for batch generation position_ids = attention_mask.long().cumsum(-1) - 1 position_ids.masked_fill_(attention_mask == 0, 1) if past_key_values: - position_ids = position_ids[:, -1].unsqueeze(-1) + position_ids = position_ids[:, -input_ids.shape[1] :] else: position_ids = None # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly if attention_mask is None: attention_mask = input_ids.new_ones(input_ids.shape) - - if past_key_values: - input_ids = input_ids[:, -1:] # first step, decoder_cached_states are empty return { "input_ids": input_ids, # encoder_outputs is defined. input_ids not needed diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 761e96a11b..da454b1e33 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -1011,9 +1011,18 @@ class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 025bab3887..26e0361abd 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -970,9 +970,18 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 61002bd277..28fddc2fdb 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -1118,9 +1118,18 @@ class XmodForCausalLM(XmodPreTrainedModel): if attention_mask is None: attention_mask = input_ids.new_ones(input_shape) - # cut decoder_input_ids if past is used + # cut decoder_input_ids if past_key_values is used if past_key_values is not None: - input_ids = input_ids[:, -1:] + past_length = past_key_values[0][0].shape[2] + + # Some generation methods already pass only the last input ID + if input_ids.shape[1] > past_length: + remove_prefix_length = past_length + else: + # Default to old behavior: keep only final ID + remove_prefix_length = input_ids.shape[1] - 1 + + input_ids = input_ids[:, remove_prefix_length:] return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index f73e3f60a5..8e3079f748 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2906,3 +2906,89 @@ class GenerationIntegrationTests(unittest.TestCase, GenerationIntegrationTestsMi model.generation_config.max_length = 10 model.generate(input_ids) self.assertEqual(len(warning_list), 0) + + def test_model_kwarg_assisted_decoding_decoder_only(self): + # PT-only test: TF doesn't support assisted decoding yet. + model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2") + model.config.pad_token_id = tokenizer.eos_token_id + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with token_type_ids + outputs_tti = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + ) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_tti.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2").to(torch_device) + assistant.config.pad_token_id = tokenizer.eos_token_id + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + token_type_ids=torch.zeros(input_ids.shape, dtype=torch.long).to(torch_device), + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_tti.tolist()) + + def test_model_kwarg_assisted_decoding_encoder_decoder(self): + # PT-only test: TF doesn't support assisted decoding yet. + # Bart subclass with a kwarg that distorts the output + class FakeBart(BartForConditionalGeneration): + def forward(self, input_ids, foo=False, **kwargs): + outs = super().forward(input_ids, **kwargs) + + if foo: + outs["logits"][:, :, :] = 0.0 + + return outs + + def prepare_inputs_for_generation(self, *args, foo=False, **kwargs): + inputs = super().prepare_inputs_for_generation(*args, **kwargs) + + inputs["foo"] = foo + return inputs + + model = FakeBart.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration").to( + torch_device + ) + tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-BartForConditionalGeneration") + + text = "Hello world" + tokenized_inputs = tokenizer([text], return_tensors="pt") + input_ids = tokenized_inputs.input_ids.to(torch_device) + + # Traditional way of generating text + outputs_normal = model.generate(input_ids) + self.assertEqual(outputs_normal.shape, (1, 20)) + + # Should be different with foo + outputs_foo = model.generate( + input_ids, + foo=True, + ) + with self.assertRaises(AssertionError): + self.assertListEqual(outputs_foo.tolist(), outputs_normal.tolist()) + + # Assistant model + assistant = AutoModelForSeq2SeqLM.from_pretrained( + "hf-internal-testing/tiny-random-BartForConditionalGeneration" + ).to(torch_device) + + # If assisted generation passes model_kwargs correctly, should be same as previous + outputs_assisted = model.generate( + input_ids, + foo=True, + assistant_model=assistant, + ) + self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())