From 12eb528b5a88d66b81957139ae452acec99a083a Mon Sep 17 00:00:00 2001 From: Arthur <48595927+ArthurZucker@users.noreply.github.com> Date: Tue, 7 Feb 2023 09:51:35 +0100 Subject: [PATCH] [CI ] Remove `past` in favor of `past_key_values` (#21443) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix past renamed to past_key_values * update more `past` that were skipped * fixup * remove changes made to rag * refactor `_reorder_cache` to use `past_key_values` * fix git `prepare_inputs_for_generation` to pass tests when false is needed in use_cache --- src/transformers/generation/utils.py | 2 +- src/transformers/models/bart/modeling_bart.py | 8 ++++---- src/transformers/models/bert/modeling_bert.py | 4 ++-- .../modeling_bigbird_pegasus.py | 8 ++++---- .../models/biogpt/modeling_biogpt.py | 4 ++-- .../models/blenderbot/modeling_blenderbot.py | 8 ++++---- .../modeling_blenderbot_small.py | 8 ++++---- .../models/blip/modeling_blip_text.py | 4 ++-- .../models/camembert/modeling_camembert.py | 4 ++-- .../models/codegen/modeling_codegen.py | 6 ++++-- src/transformers/models/ctrl/modeling_ctrl.py | 6 ++++-- .../models/data2vec/modeling_data2vec_text.py | 4 ++-- .../models/electra/modeling_electra.py | 4 ++-- .../modeling_tf_encoder_decoder.py | 4 ++-- .../models/ernie/modeling_ernie.py | 4 ++-- src/transformers/models/fsmt/modeling_fsmt.py | 4 ++-- src/transformers/models/git/modeling_git.py | 14 ++++++++------ src/transformers/models/gpt2/modeling_gpt2.py | 12 ++++++++---- .../models/gpt_neo/modeling_gpt_neo.py | 6 ++++-- .../models/gpt_neox/modeling_gpt_neox.py | 4 ++-- .../modeling_gpt_neox_japanese.py | 4 ++-- src/transformers/models/gptj/modeling_gptj.py | 6 ++++-- .../models/imagegpt/modeling_imagegpt.py | 6 ++++-- src/transformers/models/led/modeling_led.py | 4 ++-- .../models/longt5/modeling_longt5.py | 8 ++++---- .../models/m2m_100/modeling_m2m_100.py | 4 ++-- .../models/marian/modeling_marian.py | 8 ++++---- 
.../models/markuplm/modeling_markuplm.py | 4 ++-- .../models/mbart/modeling_mbart.py | 8 ++++---- .../megatron_bert/modeling_megatron_bert.py | 4 ++-- src/transformers/models/mt5/modeling_mt5.py | 8 ++++---- src/transformers/models/mvp/modeling_mvp.py | 8 ++++---- src/transformers/models/opt/modeling_opt.py | 4 ++-- .../models/pegasus/modeling_pegasus.py | 8 ++++---- .../models/pegasus_x/modeling_pegasus_x.py | 4 ++-- .../models/plbart/modeling_plbart.py | 8 ++++---- .../models/prophetnet/modeling_prophetnet.py | 8 ++++---- .../models/qdqbert/modeling_qdqbert.py | 4 ++-- src/transformers/models/rag/modeling_rag.py | 4 ++-- .../models/reformer/modeling_reformer.py | 4 ++-- .../models/rembert/modeling_rembert.py | 4 ++-- .../models/roberta/modeling_roberta.py | 4 ++-- .../modeling_roberta_prelayernorm.py | 4 ++-- .../models/roc_bert/modeling_roc_bert.py | 4 ++-- .../models/roformer/modeling_roformer.py | 4 ++-- .../modeling_speech_encoder_decoder.py | 8 ++++---- .../speech_to_text/modeling_speech_to_text.py | 4 ++-- .../modeling_speech_to_text_2.py | 4 ++-- .../models/speecht5/modeling_speecht5.py | 10 +++++----- .../modeling_switch_transformers.py | 8 ++++---- src/transformers/models/t5/modeling_t5.py | 8 ++++---- .../models/trocr/modeling_trocr.py | 4 ++-- .../modeling_vision_encoder_decoder.py | 4 ++-- .../models/whisper/modeling_whisper.py | 4 ++-- src/transformers/models/xglm/modeling_xglm.py | 4 ++-- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 8 ++++---- .../xlm_roberta/modeling_xlm_roberta.py | 4 ++-- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 4 ++-- ...ng_{{cookiecutter.lowercase_modelname}}.py | 12 ++++++------ tests/models/gpt2/test_modeling_tf_gpt2.py | 19 +++++++++++++------ tests/models/gptj/test_modeling_tf_gptj.py | 19 +++++++++++++------ 61 files changed, 204 insertions(+), 174 deletions(-) diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 1d9e53168e..5b77e79159 100644 --- 
a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -724,7 +724,7 @@ class GenerationMixin: return model_kwargs - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): raise NotImplementedError( f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" f" enable beam search for {self.__class__}" diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2aac055f85..903231cc06 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1444,9 +1444,9 @@ class BartForConditionalGeneration(BartPretrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1921,8 +1921,8 @@ class BartForCausalLM(BartPretrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index eb0e0d2166..9c8cf80440 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -1288,9 +1288,9 @@ class BertLMHeadModel(BertPreTrainedModel): "use_cache": use_cache, } - def _reorder_cache(self, past, beam_idx): + def 
_reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 21a4c7ade4..bb91e62b03 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2629,9 +2629,9 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -3098,8 +3098,8 @@ class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 3fd9c823f9..c4c89a9f4e 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -714,8 +714,8 @@ class BioGptForCausalLM(BioGptPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): 
reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 22c647d67c..482e3d349c 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -1401,9 +1401,9 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1624,8 +1624,8 @@ class BlenderbotForCausalLM(BlenderbotPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 0f5e5a46c9..9927d6ab4e 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -1367,9 +1367,9 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached 
cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1590,8 +1590,8 @@ class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index c44cf3b0df..3ec4a994d7 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -925,8 +925,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel): "is_decoder": True, } - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index 2a7b4ecbfa..81352b9cca 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -1563,9 +1563,9 @@ class CamembertForCausalLM(CamembertPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return 
reordered_past diff --git a/src/transformers/models/codegen/modeling_codegen.py b/src/transformers/models/codegen/modeling_codegen.py index 87a2a986c8..d3f4ba6219 100644 --- a/src/transformers/models/codegen/modeling_codegen.py +++ b/src/transformers/models/codegen/modeling_codegen.py @@ -727,7 +727,9 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -735,5 +737,5 @@ class CodeGenForCausalLM(CodeGenPreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/ctrl/modeling_ctrl.py b/src/transformers/models/ctrl/modeling_ctrl.py index 58c1859f2a..b41b7c5d1b 100644 --- a/src/transformers/models/ctrl/modeling_ctrl.py +++ b/src/transformers/models/ctrl/modeling_ctrl.py @@ -624,7 +624,9 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct @@ -632,7 +634,7 @@ class CTRLLMHeadModel(CTRLPreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 737209142c..aa89c8f5d0 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -1024,9 +1024,9 @@ class Data2VecTextForCausalLM(Data2VecTextPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 8801b0de9a..7d2a06a8ed 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -1671,8 +1671,8 @@ class ElectraForCausalLM(ElectraPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM._reorder_cache - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git 
a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py index a9441c3229..5b97d1a4c9 100644 --- a/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py +++ b/src/transformers/models/encoder_decoder/modeling_tf_encoder_decoder.py @@ -690,9 +690,9 @@ class TFEncoderDecoderModel(TFPreTrainedModel, TFCausalLanguageModelingLoss): ) def prepare_inputs_for_generation( - self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None past_key_values = decoder_inputs.get("past_key_values") if past_key_values is None: diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 25e9e2e251..8f178d64a9 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -1230,9 +1230,9 @@ class ErnieForCausalLM(ErniePreTrainedModel): } # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py index 408967bd05..4ad4dec842 100644 --- a/src/transformers/models/fsmt/modeling_fsmt.py +++ 
b/src/transformers/models/fsmt/modeling_fsmt.py @@ -1291,9 +1291,9 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel): return shift_tokens_right(labels, self.config.pad_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = [] - for layer_past in past: + for layer_past in past_key_values: # get the correct batch idx from decoder layer's batch dim for cross and self-attn layer_past_new = { attn_key: _reorder_buffer(attn_cache, beam_idx) for attn_key, attn_cache in layer_past.items() diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 96fa94d9a7..9d75451292 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -1512,9 +1512,11 @@ class GitForCausalLM(GitPreTrainedModel): attentions=outputs.attentions, ) - def prepare_inputs_for_generation(self, input_ids, past=None, attention_mask=None, use_cache=True, **kwargs): - # cut decoder_input_ids if past is used - if past is not None: + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, **kwargs + ): + # cut decoder_input_ids if past_key_values is used + if past_key_values is not None: input_ids = input_ids[:, -1:] # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly @@ -1526,12 +1528,12 @@ class GitForCausalLM(GitPreTrainedModel): "input_ids": input_ids, "attention_mask": attention_mask, "pixel_values": kwargs.get("pixel_values", None), - "past_key_values": past, + "past_key_values": past_key_values, "use_cache": use_cache, } - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past 
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 00d2bd7f11..02270f00a5 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -1117,7 +1117,9 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -1125,7 +1127,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) @@ -1336,7 +1338,9 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct @@ -1344,7 +1348,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/gpt_neo/modeling_gpt_neo.py b/src/transformers/models/gpt_neo/modeling_gpt_neo.py index 68084f3b7c..82bdf8cee4 100755 --- a/src/transformers/models/gpt_neo/modeling_gpt_neo.py +++ b/src/transformers/models/gpt_neo/modeling_gpt_neo.py @@ -782,7 +782,9 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or [`~PretrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct @@ -790,7 +792,7 @@ class GPTNeoForCausalLM(GPTNeoPreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 7fed1ad556..874cf719d7 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -700,9 +700,9 @@ class GPTNeoXForCausalLM(GPTNeoXPreTrainedModel): "past_key_values": past_key_values, } - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) diff --git a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py index 8badb60f97..cc3e7bd2c9 100755 --- a/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +++ b/src/transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py @@ -713,9 +713,9 @@ class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) diff --git a/src/transformers/models/gptj/modeling_gptj.py b/src/transformers/models/gptj/modeling_gptj.py index bb19df5e7f..bbae317ab5 100755 --- 
a/src/transformers/models/gptj/modeling_gptj.py +++ b/src/transformers/models/gptj/modeling_gptj.py @@ -878,7 +878,9 @@ class GPTJForCausalLM(GPTJPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct @@ -886,7 +888,7 @@ class GPTJForCausalLM(GPTJPreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index d7537c9dc3..21305de732 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -1060,7 +1060,9 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): ) @staticmethod - def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]: + def _reorder_cache( + past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor + ) -> Tuple[Tuple[torch.Tensor]]: """ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or [`~PreTrainedModel.beam_sample`] is called. 
This is required to match `past_key_values` with the correct @@ -1068,7 +1070,7 @@ class ImageGPTForCausalImageModeling(ImageGPTPreTrainedModel): """ return tuple( tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past + for layer_past in past_key_values ) diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index b60f088de8..5181be6b83 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -2508,9 +2508,9 @@ class LEDForConditionalGeneration(LEDPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 1039b1cc5b..316781c623 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -2114,15 +2114,15 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past is None: + if past_key_values is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past + return past_key_values reordered_decoder_past = () - for layer_past_states 
in past: + for layer_past_states in past_key_values: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 58f0ca289c..88c5c18b99 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -1395,8 +1395,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index d4dc4d53dd..46c151ca81 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -1521,9 +1521,9 @@ class MarianMTModel(MarianPreTrainedModel): return logits @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1742,8 +1742,8 @@ class MarianForCausalLM(MarianPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git 
a/src/transformers/models/markuplm/modeling_markuplm.py b/src/transformers/models/markuplm/modeling_markuplm.py index 86c67f7fe4..9c9e3ffb58 100755 --- a/src/transformers/models/markuplm/modeling_markuplm.py +++ b/src/transformers/models/markuplm/modeling_markuplm.py @@ -955,9 +955,9 @@ class MarkupLMModel(MarkupLMPreTrainedModel): } # Copied from transformers.models.bert.modeling_bert.BertModel._reorder_cache - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 2548b08ed1..688fc9fc9c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -1418,9 +1418,9 @@ class MBartForConditionalGeneration(MBartPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1889,8 +1889,8 @@ class MBartForCausalLM(MBartPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py 
b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index 5b7889adbc..c98452891a 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -1255,9 +1255,9 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index dd047fd7c6..1a61c493f6 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -1808,15 +1808,15 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel): return self._shift_right(labels) # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._reorder_cache - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past is None: + if past_key_values is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past + return past_key_values reordered_decoder_past = () - for layer_past_states in past: + for layer_past_states in past_key_values: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index dde522535d..b6a6a9c328 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ 
b/src/transformers/models/mvp/modeling_mvp.py @@ -1579,9 +1579,9 @@ class MvpForConditionalGeneration(MvpPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -2053,8 +2053,8 @@ class MvpForCausalLM(MvpPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index 5269c67de7..493a3f1b7c 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -982,9 +982,9 @@ class OPTForCausalLM(OPTPreTrainedModel): return model_inputs @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index d4cc8f381e..85636e6a80 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -1478,9 +1478,9 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel): return shift_tokens_right(labels, 
self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1721,8 +1721,8 @@ class PegasusForCausalLM(PegasusPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index f09a0447fd..5e55bd1a19 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1681,9 +1681,9 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ed506a6935..12f761c4f9 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1391,9 +1391,9 @@ class 
PLBartForConditionalGeneration(PLBartPreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id) @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -1739,8 +1739,8 @@ class PLBartForCausalLM(PLBartPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 7dbef7c6e2..f94c8efad3 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -2090,9 +2090,9 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel): @staticmethod # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -2336,9 +2336,9 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): @staticmethod # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, 
beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/qdqbert/modeling_qdqbert.py b/src/transformers/models/qdqbert/modeling_qdqbert.py index ba24b89723..ef1ad54381 100755 --- a/src/transformers/models/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/qdqbert/modeling_qdqbert.py @@ -1156,9 +1156,9 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py index df6b56d05c..6941bca09c 100644 --- a/src/transformers/models/rag/modeling_rag.py +++ b/src/transformers/models/rag/modeling_rag.py @@ -1205,7 +1205,7 @@ class RagTokenForGeneration(RagPreTrainedModel): return self.rag.question_encoder @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): """Reorders cache for generation. 
BART-inspired but we need to take care of the extra dimension for docs""" def _reorder_stacked(hidden_states, new_order): @@ -1216,7 +1216,7 @@ class RagTokenForGeneration(RagPreTrainedModel): return result reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # get the correct batch idx from decoder layer's batch dim for cross and self-attn reordered_past += (tuple(_reorder_stacked(past_state, beam_idx) for past_state in layer_past),) diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index a9ca155bb1..1b6aba0b7f 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -2298,9 +2298,9 @@ class ReformerModelWithLMHead(ReformerPreTrainedModel): return inputs_dict - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reord_past_buckets_states = [] - for layer_past in past: + for layer_past in past_key_values: # buckets if layer_past[0] is not None: reord_buckets = layer_past[0].index_select(0, beam_idx) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index 9a65cb97ab..92a313e4fd 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -1148,9 +1148,9 @@ class RemBertForCausalLM(RemBertPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py 
index 2b6d47b420..18fae6d920 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -1022,9 +1022,9 @@ class RobertaForCausalLM(RobertaPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index e5fbb6e341..5b8b0290c3 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -1029,9 +1029,9 @@ class RobertaPreLayerNormForCausalLM(RobertaPreLayerNormPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index ac3d374a55..c8c85ff142 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -1575,9 +1575,9 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel): } # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel._reorder_cache - def _reorder_cache(self, past, beam_idx): + def 
_reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 3631b9704f..5edbd39ded 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -1187,9 +1187,9 @@ class RoFormerForCausalLM(RoFormerPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], ) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 0bc4134d4b..e80c26e269 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -583,9 +583,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel): return shift_tokens_right(labels, self.config.pad_token_id, self.config.decoder_start_token_id) def prepare_inputs_for_generation( - self, input_ids, past=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs + self, input_ids, past_key_values=None, attention_mask=None, use_cache=None, encoder_outputs=None, **kwargs ): - decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past=past) + decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids, past_key_values=past_key_values) 
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None input_dict = { "attention_mask": attention_mask, @@ -603,6 +603,6 @@ class SpeechEncoderDecoderModel(PreTrainedModel): " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): # apply decoder cache reordering here - return self.decoder._reorder_cache(past, beam_idx) + return self.decoder._reorder_cache(past_key_values, beam_idx) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index e68009a3ae..ad26633cfa 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -1418,8 +1418,8 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py index 8d5c508b98..40530e111b 100755 --- a/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py @@ -967,8 +967,8 @@ class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in 
layer_past),) return reordered_past diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index d9c381fdc8..86f7b81ca0 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -2368,7 +2368,7 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): def prepare_inputs_for_generation( self, decoder_input_ids, - past=None, + past_key_values=None, attention_mask=None, head_mask=None, decoder_head_mask=None, @@ -2378,12 +2378,12 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): **kwargs, ): # cut decoder_input_ids if past is used - if past is not None: + if past_key_values is not None: decoder_input_ids = decoder_input_ids[:, -1:] return { "encoder_outputs": encoder_outputs, - "past_key_values": past, + "past_key_values": past_key_values, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "head_mask": head_mask, @@ -2393,9 +2393,9 @@ class SpeechT5ForSpeechToText(SpeechT5PreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index 77fe3be9c3..e71bbece3e 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -1774,15 +1774,15 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past, beam_idx): + 
def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past is None: + if past_key_values is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past + return past_key_values reordered_decoder_past = () - for layer_past_states in past: + for layer_past_states in past_key_values: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 0e93f51e70..1f41e66264 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1773,15 +1773,15 @@ class T5ForConditionalGeneration(T5PreTrainedModel): def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): return self._shift_right(labels) - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): # if decoder past is not included in output # speedy decoding is disabled and no need to reorder - if past is None: + if past_key_values is None: logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") - return past + return past_key_values reordered_decoder_past = () - for layer_past_states in past: + for layer_past_states in past_key_values: # get the correct batch idx from layer past batch dim # batch dim of `past` is at 2nd position reordered_layer_past_states = () diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 98ab48a938..5eda7479b4 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -1007,8 +1007,8 @@ class TrOCRForCausalLM(TrOCRPreTrainedModel): } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, 
beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 34ead71465..2cac544377 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -669,6 +669,6 @@ class VisionEncoderDecoderModel(PreTrainedModel): " respective methods of the wrapped decoder object (model.decoder.resize_token_embeddings(...))" ) - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): # apply decoder cache reordering here - return self.decoder._reorder_cache(past, beam_idx) + return self.decoder._reorder_cache(past_key_values, beam_idx) diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 13e3ad7abb..0313613b70 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -1396,8 +1396,8 @@ class WhisperForConditionalGeneration(WhisperPreTrainedModel): # @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 6fb3491743..f90e6f4ca3 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -942,8 +942,8 @@ class XGLMForCausalLM(XGLMPreTrainedModel): } 
@staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index 0327bb4f08..ed005c68aa 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -2117,9 +2117,9 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel): @staticmethod # Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration._reorder_cache - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: # cached cross_attention states don't have to be reordered -> they are always the same reordered_past += ( tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:], @@ -2364,9 +2364,9 @@ class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel): @staticmethod # Copied from transformers.models.bart.modeling_bart.BartForCausalLM._reorder_cache - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index 30d75a2949..650e523876 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -1026,9 +1026,9 @@ 
class XLMRobertaForCausalLM(XLMRobertaPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 9250fa9639..175254eb83 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -988,9 +988,9 @@ class XLMRobertaXLForCausalLM(XLMRobertaXLPreTrainedModel): return {"input_ids": input_ids, "attention_mask": attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index a64bd6c8fd..869c1e3402 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -1180,9 +1180,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m return {"input_ids": input_ids, "attention_mask": 
attention_mask, "past_key_values": past_key_values} - def _reorder_cache(self, past, beam_idx): + def _reorder_cache(self, past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past[:2]) + layer_past[2:],) return reordered_past @@ -2905,9 +2905,9 @@ class {{cookiecutter.camelcase_modelname}}ForConditionalGeneration({{cookiecutte } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past @@ -3344,9 +3344,9 @@ class {{cookiecutter.camelcase_modelname}}ForCausalLM({{cookiecutter.camelcase_m } @staticmethod - def _reorder_cache(past, beam_idx): + def _reorder_cache(past_key_values, beam_idx): reordered_past = () - for layer_past in past: + for layer_past in past_key_values: reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) return reordered_past {% endif -%} diff --git a/tests/models/gpt2/test_modeling_tf_gpt2.py b/tests/models/gpt2/test_modeling_tf_gpt2.py index 298d501e58..3534ca7cd0 100644 --- a/tests/models/gpt2/test_modeling_tf_gpt2.py +++ b/tests/models/gpt2/test_modeling_tf_gpt2.py @@ -180,7 +180,7 @@ class TFGPT2ModelTester: self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - output, past = outputs.to_tuple() + output, past_key_values = outputs.to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -191,7 +191,9 @@ class TFGPT2ModelTester: next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) output_from_no_past = model(next_input_ids, 
token_type_ids=next_token_type_ids)["last_hidden_state"] - output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past_key_values)[ + "last_hidden_state" + ] # select random slice random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) @@ -213,7 +215,7 @@ class TFGPT2ModelTester: attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) # first forward pass - output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + output, past_key_values = model(input_ids, attention_mask=attn_mask).to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -233,7 +235,9 @@ class TFGPT2ModelTester: # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past=past, attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] # select random slice random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) @@ -256,7 +260,7 @@ class TFGPT2ModelTester: # first forward pass outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True) - output, past = outputs.to_tuple() + output, past_key_values = outputs.to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) @@ -272,7 +276,10 @@ class TFGPT2ModelTester: next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask )["last_hidden_state"] output_from_past = model( - next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past=past + next_tokens, + token_type_ids=next_token_types, + 
attention_mask=next_attention_mask, + past_key_values=past_key_values, )["last_hidden_state"] self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1]) diff --git a/tests/models/gptj/test_modeling_tf_gptj.py b/tests/models/gptj/test_modeling_tf_gptj.py index 6557428c07..0113042ae0 100644 --- a/tests/models/gptj/test_modeling_tf_gptj.py +++ b/tests/models/gptj/test_modeling_tf_gptj.py @@ -148,7 +148,7 @@ class TFGPTJModelTester: self.parent.assertTrue(len(outputs) == len(outputs_use_cache_conf)) self.parent.assertTrue(len(outputs) == len(outputs_no_past) + 1) - output, past = outputs.to_tuple() + output, past_key_values = outputs.to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -159,7 +159,9 @@ class TFGPTJModelTester: next_token_type_ids = tf.concat([token_type_ids, next_token_types], axis=-1) output_from_no_past = model(next_input_ids, token_type_ids=next_token_type_ids)["last_hidden_state"] - output_from_past = model(next_tokens, token_type_ids=next_token_types, past=past)["last_hidden_state"] + output_from_past = model(next_tokens, token_type_ids=next_token_types, past_key_values=past_key_values)[ + "last_hidden_state" + ] # select random slice random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) @@ -181,7 +183,7 @@ class TFGPTJModelTester: attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1) # first forward pass - output, past = model(input_ids, attention_mask=attn_mask).to_tuple() + output, past_key_values = model(input_ids, attention_mask=attn_mask).to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size) @@ -201,7 +203,9 @@ class TFGPTJModelTester: # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past=past, 
attention_mask=attn_mask)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ + "last_hidden_state" + ] # select random slice random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1])) @@ -224,7 +228,7 @@ class TFGPTJModelTester: # first forward pass outputs = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, use_cache=True) - output, past = outputs.to_tuple() + output, past_key_values = outputs.to_tuple() # create hypothetical next token and extent to next_input_ids next_tokens = ids_tensor((self.batch_size, 3), config.vocab_size) @@ -240,7 +244,10 @@ class TFGPTJModelTester: next_input_ids, token_type_ids=next_token_type_ids, attention_mask=next_attention_mask )["last_hidden_state"] output_from_past = model( - next_tokens, token_type_ids=next_token_types, attention_mask=next_attention_mask, past=past + next_tokens, + token_type_ids=next_token_types, + attention_mask=next_attention_mask, + past_key_values=past_key_values, )["last_hidden_state"] self.parent.assertTrue(output_from_past.shape[1] == next_tokens.shape[1])