From 5c3fb7f731743609262d7c77a636b41f69d204fe Mon Sep 17 00:00:00 2001 From: Cyril Vallez Date: Fri, 8 Aug 2025 11:52:57 +0200 Subject: [PATCH] Harmonize `past_key_value` to `past_key_valueS` everywhere (#39956) * all modulars and llama * apply modular * bert and gpt2 copies * fix imports * do it everywhere * fix import * finalize it * fix * oups set it in modular * style * fix * Add 1 version to deprecation cycle * Update modeling_layers.py --- docs/source/en/modular_transformers.md | 10 +- .../modeling_add_function.py | 3 + .../modeling_dummy_bert.py | 288 +++++++++--------- .../modeling_global_indexing.py | 8 +- .../modeling_multimodal2.py | 4 +- .../modeling_my_new_model2.py | 258 +--------------- .../modeling_new_task_model.py | 56 ++-- .../modular-transformers/modeling_roberta.py | 288 +++++++++--------- .../modular-transformers/modeling_super.py | 31 +- .../modeling_switch_function.py | 8 +- .../modeling_test_detr.py | 2 +- .../modular_dummy_bert.py | 1 + src/transformers/cache_utils.py | 12 +- src/transformers/modeling_layers.py | 1 + .../models/arcee/modeling_arcee.py | 15 +- src/transformers/models/aria/modeling_aria.py | 15 +- .../models/autoformer/modeling_autoformer.py | 33 +- .../models/bamba/modeling_bamba.py | 19 +- .../models/bamba/modular_bamba.py | 12 +- src/transformers/models/bart/modeling_bart.py | 35 ++- src/transformers/models/bert/modeling_bert.py | 61 ++-- .../modeling_bert_generation.py | 38 +-- .../models/big_bird/modeling_big_bird.py | 26 +- .../modeling_bigbird_pegasus.py | 47 +-- .../models/biogpt/modeling_biogpt.py | 31 +- .../models/biogpt/modular_biogpt.py | 10 +- .../models/bitnet/modeling_bitnet.py | 15 +- .../models/bitnet/modular_bitnet.py | 8 +- .../models/blenderbot/modeling_blenderbot.py | 35 ++- .../modeling_blenderbot_small.py | 35 ++- src/transformers/models/blip/modeling_blip.py | 2 +- .../models/blip/modeling_blip_text.py | 34 ++- .../bridgetower/modeling_bridgetower.py | 45 +-- .../models/camembert/modeling_camembert.py | 61 ++-- .../models/chameleon/modeling_chameleon.py | 24 +- src/transformers/models/clvp/modeling_clvp.py | 17 +- .../models/cohere/modeling_cohere.py | 17 +- .../models/cohere/modular_cohere.py | 15 +- .../models/cohere2/modeling_cohere2.py | 17 +- .../models/cohere2/modular_cohere2.py | 15 +- src/transformers/models/csm/modeling_csm.py | 17 +- src/transformers/models/csm/modular_csm.py | 2 +- .../models/data2vec/modeling_data2vec_text.py | 38 +-- src/transformers/models/dbrx/modeling_dbrx.py | 40 +-- .../modeling_decision_transformer.py | 31 +- .../deepseek_v2/modeling_deepseek_v2.py | 15 +- .../models/deepseek_v2/modular_deepseek_v2.py | 8 +- .../deepseek_v3/modeling_deepseek_v3.py | 15 +- .../models/deepseek_v3/modular_deepseek_v3.py | 8 +- .../deprecated/ernie_m/modeling_ernie_m.py | 39 +-- .../modeling_gptsan_japanese.py | 44 +-- .../models/deprecated/mega/modeling_mega.py | 16 +- .../models/deprecated/nezha/modeling_nezha.py | 41 +-- .../open_llama/modeling_open_llama.py | 29 +- .../deprecated/qdqbert/modeling_qdqbert.py | 41 +-- .../models/deprecated/realm/modeling_realm.py | 43 +-- .../modeling_speech_to_text_2.py | 45 +-- .../xlm_prophetnet/modeling_xlm_prophetnet.py | 44 +-- src/transformers/models/dia/modeling_dia.py | 8 +- .../models/diffllama/modeling_diffllama.py | 31 +- .../models/diffllama/modular_diffllama.py | 24 +- src/transformers/models/doge/modeling_doge.py | 15 +- src/transformers/models/doge/modular_doge.py | 13 +- .../models/dots1/modeling_dots1.py | 15 +- .../models/electra/modeling_electra.py | 38 +-- src/transformers/models/emu3/modeling_emu3.py | 15 +- src/transformers/models/emu3/modular_emu3.py | 6 +- .../models/ernie/modeling_ernie.py | 39 +-- .../models/ernie4_5/modeling_ernie4_5.py | 15 +- .../ernie4_5_moe/modeling_ernie4_5_moe.py | 17 +- .../ernie4_5_moe/modular_ernie4_5_moe.py | 2 +- .../models/evolla/modeling_evolla.py | 18 +- .../models/evolla/modular_evolla.py | 11 +- .../models/exaone4/modeling_exaone4.py | 15 +- .../models/exaone4/modular_exaone4.py | 10 +- .../models/falcon_h1/modeling_falcon_h1.py | 19 +- .../models/falcon_h1/modular_falcon_h1.py | 19 +- .../models/funnel/modeling_tf_funnel.py | 24 +- .../models/gemma/modeling_gemma.py | 15 +- .../models/gemma/modular_gemma.py | 2 +- .../models/gemma2/modeling_gemma2.py | 15 +- .../models/gemma2/modular_gemma2.py | 15 +- .../models/gemma3/modeling_gemma3.py | 15 +- .../models/gemma3/modular_gemma3.py | 15 +- .../models/gemma3n/modeling_gemma3n.py | 19 +- .../models/gemma3n/modular_gemma3n.py | 19 +- src/transformers/models/git/modeling_git.py | 20 +- src/transformers/models/glm/modeling_glm.py | 15 +- src/transformers/models/glm4/modeling_glm4.py | 15 +- src/transformers/models/glm4/modular_glm4.py | 6 +- .../models/glm4_moe/modeling_glm4_moe.py | 15 +- .../models/glm4v/modeling_glm4v.py | 17 +- .../models/glm4v/modular_glm4v.py | 17 +- src/transformers/models/gpt2/modeling_gpt2.py | 31 +- .../models/gpt_neox/modeling_gpt_neox.py | 6 +- .../models/gpt_oss/modeling_gpt_oss.py | 15 +- .../models/gpt_oss/modular_gpt_oss.py | 15 +- .../models/granite/modeling_granite.py | 17 +- .../models/granite/modular_granite.py | 10 +- .../models/granitemoe/modeling_granitemoe.py | 17 +- .../modeling_granitemoehybrid.py | 19 +- .../modular_granitemoehybrid.py | 12 +- .../modeling_granitemoeshared.py | 17 +- .../modular_granitemoeshared.py | 8 +- .../models/helium/modeling_helium.py | 15 +- .../models/idefics/modeling_idefics.py | 28 +- .../models/idefics2/modeling_idefics2.py | 23 +- .../models/informer/modeling_informer.py | 54 ++-- .../models/informer/modular_informer.py | 22 +- .../models/jamba/modeling_jamba.py | 47 +-- .../models/jetmoe/modeling_jetmoe.py | 31 +- .../models/kosmos2/modeling_kosmos2.py | 33 +- .../modeling_kyutai_speech_to_text.py | 37 +-- src/transformers/models/led/modeling_led.py | 39 +-- src/transformers/models/lfm2/modeling_lfm2.py | 48 +-- src/transformers/models/lfm2/modular_lfm2.py | 48 +-- .../models/lightglue/modeling_lightglue.py | 2 + .../models/llama/modeling_llama.py | 15 +- .../models/llama4/modeling_llama4.py | 17 +- .../models/longt5/modeling_longt5.py | 43 +-- .../models/m2m_100/modeling_m2m_100.py | 36 ++- .../models/marian/modeling_marian.py | 36 ++- .../models/mbart/modeling_mbart.py | 35 ++- .../megatron_bert/modeling_megatron_bert.py | 36 ++- src/transformers/models/mimi/modeling_mimi.py | 35 ++- .../models/minimax/modeling_minimax.py | 28 +- .../models/minimax/modular_minimax.py | 21 +- .../models/mistral/modeling_mistral.py | 15 +- .../models/mistral/modular_mistral.py | 10 +- .../models/mixtral/modeling_mixtral.py | 15 +- .../models/mixtral/modular_mixtral.py | 8 +- .../models/mllama/modeling_mllama.py | 33 +- .../modeling_modernbert_decoder.py | 15 +- .../modular_modernbert_decoder.py | 15 +- .../models/moonshine/modeling_moonshine.py | 42 +-- .../models/moonshine/modular_moonshine.py | 37 +-- .../models/moshi/modeling_moshi.py | 39 +-- src/transformers/models/mpt/modeling_mpt.py | 19 +- src/transformers/models/mt5/modeling_mt5.py | 39 +-- .../models/musicgen/modeling_musicgen.py | 37 +-- .../modeling_musicgen_melody.py | 33 +- src/transformers/models/mvp/modeling_mvp.py | 35 ++- .../models/nemotron/modeling_nemotron.py | 37 +-- .../models/nllb_moe/modeling_nllb_moe.py | 35 ++- src/transformers/models/olmo/modeling_olmo.py | 15 +- src/transformers/models/olmo/modular_olmo.py | 8 +- .../models/olmo2/modeling_olmo2.py | 15 +- .../models/olmo2/modular_olmo2.py | 13 +- .../models/olmoe/modeling_olmoe.py | 33 +- src/transformers/models/opt/modeling_opt.py | 19 +- .../models/pegasus/modeling_pegasus.py | 35 ++- .../models/pegasus_x/modeling_pegasus_x.py | 35 ++- .../models/persimmon/modeling_persimmon.py | 17 +- src/transformers/models/phi/modeling_phi.py | 15 +- src/transformers/models/phi/modular_phi.py | 15 +- src/transformers/models/phi3/modeling_phi3.py | 15 +- src/transformers/models/phi3/modular_phi3.py | 13 +- .../modeling_phi4_multimodal.py | 15 +- .../modular_phi4_multimodal.py | 2 +- .../models/phimoe/modeling_phimoe.py | 33 +- .../models/pix2struct/modeling_pix2struct.py | 41 +-- .../models/plbart/modeling_plbart.py | 35 ++- .../models/plbart/modular_plbart.py | 2 +- .../models/pop2piano/modeling_pop2piano.py | 39 +-- .../models/prophetnet/modeling_prophetnet.py | 42 +-- .../models/qwen2/modeling_qwen2.py | 15 +- .../models/qwen2/modular_qwen2.py | 10 +- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 19 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 17 +- .../models/qwen2_moe/modeling_qwen2_moe.py | 34 ++- .../models/qwen2_vl/modeling_qwen2_vl.py | 17 +- .../models/qwen3/modeling_qwen3.py | 15 +- .../models/qwen3/modular_qwen3.py | 8 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 17 +- .../models/qwen3_moe/modular_qwen3_moe.py | 6 +- .../models/rag/modeling_tf_rag.py | 38 +-- .../models/reformer/modeling_reformer.py | 6 +- .../models/rembert/modeling_rembert.py | 32 +- .../models/roberta/modeling_roberta.py | 61 ++-- .../modeling_roberta_prelayernorm.py | 38 +-- .../models/roc_bert/modeling_roc_bert.py | 38 +-- .../models/roformer/modeling_roformer.py | 35 ++- .../models/rt_detr/modeling_rt_detr_resnet.py | 31 +- .../seamless_m4t/modeling_seamless_m4t.py | 45 +-- .../modeling_seamless_m4t_v2.py | 43 +-- .../models/smollm3/modeling_smollm3.py | 15 +- .../models/smollm3/modular_smollm3.py | 8 +- .../speech_to_text/modeling_speech_to_text.py | 35 ++- .../models/speecht5/modeling_speecht5.py | 33 +- .../models/stablelm/modeling_stablelm.py | 32 +- .../models/starcoder2/modeling_starcoder2.py | 15 +- .../models/starcoder2/modular_starcoder2.py | 10 +- .../modeling_switch_transformers.py | 38 +-- src/transformers/models/t5/modeling_t5.py | 39 +-- .../models/t5gemma/modeling_t5gemma.py | 42 +-- .../models/t5gemma/modular_t5gemma.py | 28 +- .../models/tapas/modeling_tapas.py | 36 +-- .../modeling_time_series_transformer.py | 33 +- .../models/trocr/modeling_trocr.py | 33 +- src/transformers/models/udop/modeling_udop.py | 39 +-- src/transformers/models/umt5/modeling_umt5.py | 41 +-- .../models/whisper/modeling_whisper.py | 37 +-- src/transformers/models/xglm/modeling_xglm.py | 33 +- .../xlm_roberta/modeling_xlm_roberta.py | 61 ++-- .../xlm_roberta_xl/modeling_xlm_roberta_xl.py | 59 ++-- src/transformers/models/xmod/modeling_xmod.py | 36 ++- .../models/zamba/modeling_zamba.py | 35 ++- .../models/zamba2/modeling_zamba2.py | 35 ++- .../models/zamba2/modular_zamba2.py | 26 +- src/transformers/utils/auto_docstring.py | 7 - tests/utils/test_auto_docstring.py | 2 +- 211 files changed, 3159 insertions(+), 2733 deletions(-) diff --git a/docs/source/en/modular_transformers.md b/docs/source/en/modular_transformers.md index d97de838af..ceade678b8 100644 --- a/docs/source/en/modular_transformers.md +++ b/docs/source/en/modular_transformers.md @@ -218,7 +218,7 @@ class Olmo2Attention(OlmoAttention): hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -236,10 +236,10 @@ class Olmo2Attention(OlmoAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -290,7 +290,7 @@ class Olmo2DecoderLayer(OlmoDecoderLayer): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -304,7 +304,7 @@ class Olmo2DecoderLayer(OlmoDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/examples/modular-transformers/modeling_add_function.py b/examples/modular-transformers/modeling_add_function.py index fdc768237b..06cdc5c571 100644 --- a/examples/modular-transformers/modeling_add_function.py +++ b/examples/modular-transformers/modeling_add_function.py @@ -10,6 +10,8 @@ from typing import Optional import torch from torch import nn +from ...utils.deprecation import deprecate_kwarg + def rotate_half(x): """Rotates half the hidden dims of the input.""" @@ -62,5 +64,6 @@ class TestAttention(nn.Module): def __init__(self): pass + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward(self) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: _ = apply_rotary_pos_emb(1, 1, 1, 1) diff --git a/examples/modular-transformers/modeling_dummy_bert.py b/examples/modular-transformers/modeling_dummy_bert.py index 40bd423067..1065afe80d 100644 --- a/examples/modular-transformers/modeling_dummy_bert.py +++ b/examples/modular-transformers/modeling_dummy_bert.py @@ -13,12 +13,14 @@ from packaging import version from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_dummy_bert import DummyBertConfig @@ -90,7 +92,7 @@ class DummyBertEmbeddings(nn.Module): class DummyBertSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -115,66 +117,68 @@ class DummyBertSelfAttention(nn.Module): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -216,29 +220,26 @@ class DummyBertSelfAttention(nn.Module): new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class DummyBertSdpaSelfAttention(DummyBertSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from DummyBertSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -254,41 +255,56 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention): attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -302,9 +318,7 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention): # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) + is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -318,10 +332,7 @@ class DummyBertSdpaSelfAttention(DummyBertSelfAttention): attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class DummyBertSelfOutput(nn.Module): @@ -345,10 +356,12 @@ DUMMY_BERT_SELF_ATTENTION_CLASSES = { class DummyBertAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = DUMMY_BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = DummyBertSelfOutput(config) self.pruned_heads = set() @@ -371,24 +384,25 @@ class DummyBertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -425,20 +439,21 @@ class DummyBertOutput(nn.Module): class DummyBertLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = DummyBertAttention(config) + self.attention = DummyBertAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = DummyBertAttention(config, position_embedding_type="absolute") + self.crossattention = DummyBertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = DummyBertIntermediate(config) self.output = DummyBertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -446,28 +461,21 @@ class DummyBertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_values=past_key_values, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -475,33 +483,23 @@ class DummyBertLayer(GradientCheckpointingLayer): " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -511,10 +509,10 @@ class DummyBertLayer(GradientCheckpointingLayer): class DummyBertEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([DummyBertLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([DummyBertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -529,6 +527,7 @@ class DummyBertEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -541,13 +540,21 @@ class DummyBertEncoder(nn.Module): ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -555,13 +562,12 @@ class DummyBertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -570,12 +576,15 @@ class DummyBertEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -584,7 +593,7 @@ class DummyBertEncoder(nn.Module): ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -721,7 +730,7 @@ def load_tf_weights_in_dummy_bert(model, config, tf_checkpoint_path): @auto_docstring class DummyBertPreTrainedModel(PreTrainedModel): - config_class = DummyBertConfig + config: DummyBertConfig load_tf_weights = load_tf_weights_in_dummy_bert base_model_prefix = "dummy_bert" supports_gradient_checkpointing = True @@ -810,6 +819,7 @@ class DummyBertModel(DummyBertPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -835,8 +845,13 @@ class DummyBertModel(DummyBertPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -921,6 +936,7 @@ class DummyBertModel(DummyBertPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/examples/modular-transformers/modeling_global_indexing.py b/examples/modular-transformers/modeling_global_indexing.py index 1e462c9582..50f86222b3 100644 --- a/examples/modular-transformers/modeling_global_indexing.py +++ b/examples/modular-transformers/modeling_global_indexing.py @@ -14,6 +14,7 @@ from transformers.modeling_utils import AttentionInterface from ...cache_utils import Cache from ...processing_utils import Unpack from ...utils import TransformersKwargs +from ...utils.deprecation import deprecate_kwarg from .configuration_global_indexing import GlobalIndexingConfig @@ -125,12 +126,13 @@ class GlobalIndexingAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -144,10 +146,10 @@ class GlobalIndexingAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/examples/modular-transformers/modeling_multimodal2.py b/examples/modular-transformers/modeling_multimodal2.py index 01aca3a4d5..bb011ee126 100644 --- a/examples/modular-transformers/modeling_multimodal2.py +++ b/examples/modular-transformers/modeling_multimodal2.py @@ -493,7 +493,7 @@ class Multimodal2VisionTransformer(nn.Module): @auto_docstring class Multimodal2VisionPreTrainedModel(PreTrainedModel): - config_class = Multimodal2Config + config: Multimodal2Config base_model_prefix = "multimodal2_vision" supports_gradient_checkpointing = True _supports_sdpa = True @@ -512,7 +512,7 @@ MULTIMODAL2_VISION_START_DOCSTRING = "doc" @add_start_docstrings("New doc", MULTIMODAL2_VISION_START_DOCSTRING) class Multimodal2VisionModel(Multimodal2VisionPreTrainedModel): - config_class = Multimodal2VisionConfig + config: Multimodal2VisionConfig main_input_name = "pixel_values" _no_split_modules = ["Multimodal2VisionEncoderLayer"] diff --git a/examples/modular-transformers/modeling_my_new_model2.py b/examples/modular-transformers/modeling_my_new_model2.py index e56eeec7d7..4c1c8b0c0c 100644 --- a/examples/modular-transformers/modeling_my_new_model2.py +++ b/examples/modular-transformers/modeling_my_new_model2.py @@ -10,21 +10,15 @@ import torch from torch import nn from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...masking_utils import create_causal_mask -from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast -from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update +from ...cache_utils import Cache +from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ...utils.generic import check_model_inputs +from ...utils import TransformersKwargs, auto_docstring +from ...utils.deprecation import deprecate_kwarg from .configuration_my_new_model2 import MyNewModel2Config -logger = logging.get_logger(__name__) - - class MyNewModel2RMSNorm(nn.Module): def __init__(self, dim: int, eps: float = 1e-6): super().__init__() @@ -61,40 +55,6 @@ class MyNewModel2MLP(nn.Module): return down_proj -class MyNewModel2RotaryEmbedding(nn.Module): - def __init__(self, config: MyNewModel2Config, device=None): - super().__init__() - # BC: "rope_type" was originally "type" - if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict): - self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type")) - else: - self.rope_type = "default" - self.max_seq_len_cached = config.max_position_embeddings - self.original_max_seq_len = config.max_position_embeddings - - self.config = config - self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type] - - inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device) - self.register_buffer("inv_freq", inv_freq, persistent=False) - self.original_inv_freq = self.inv_freq - - @torch.no_grad() - @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope) - def forward(self, x, position_ids): - inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device) - position_ids_expanded = position_ids[:, None, :].float() - - device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu" - with torch.autocast(device_type=device_type, enabled=False): # Force float32 - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() * self.attention_scaling - sin = emb.sin() * self.attention_scaling - - return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) - - def rotate_half(x): """Rotates half the hidden dims of the input.""" x1 = x[..., : x.shape[-1] // 2] @@ -193,12 +153,13 @@ class MyNewModel2Attention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -212,10 +173,10 @@ class MyNewModel2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -248,12 +209,13 @@ class MyNewModel2DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -266,7 +228,7 @@ class MyNewModel2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -284,7 +246,7 @@ class MyNewModel2DecoderLayer(GradientCheckpointingLayer): @auto_docstring class MyNewModel2PreTrainedModel(PreTrainedModel): - config_class = MyNewModel2Config + config: MyNewModel2Config base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["MyNewModel2DecoderLayer"] @@ -292,8 +254,7 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { @@ -301,195 +262,6 @@ class MyNewModel2PreTrainedModel(PreTrainedModel): "attentions": MyNewModel2Attention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, MyNewModel2RMSNorm): - module.weight.data.fill_(1.0) - -@auto_docstring -class MyNewModel2Model(MyNewModel2PreTrainedModel): - def __init__(self, config: MyNewModel2Config): - super().__init__(config) - self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size - - self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) - self.layers = nn.ModuleList( - [MyNewModel2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] - ) - self.norm = MyNewModel2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.rotary_emb = MyNewModel2RotaryEmbedding(config=config) - self.gradient_checkpointing = False - - # Initialize weights and apply final processing - self.post_init() - - @check_model_inputs - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> BaseModelOutputWithPast: - if (input_ids is None) ^ (inputs_embeds is not None): - raise ValueError("You must specify exactly one of input_ids or inputs_embeds") - - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - - if use_cache and past_key_values is None: - past_key_values = DynamicCache() - - if cache_position is None: - past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 - cache_position = torch.arange( - past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device - ) - - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - causal_mask = create_causal_mask( - config=self.config, - input_embeds=inputs_embeds, - attention_mask=attention_mask, - cache_position=cache_position, - past_key_values=past_key_values, - position_ids=position_ids, - ) - - # embed positions - hidden_states = inputs_embeds - - # create position embeddings to be shared across the decoder layers - position_embeddings = self.rotary_emb(hidden_states, position_ids) - - # normalized - # MyNewModel2 downcasts the below to float16, causing sqrt(3072)=55.4256 to become 55.5 - # See https://github.com/huggingface/transformers/pull/29402 - normalizer = torch.tensor(self.config.hidden_size**0.5, dtype=hidden_states.dtype) - hidden_states = hidden_states * normalizer - - for decoder_layer in self.layers[: self.config.num_hidden_layers]: - hidden_states = decoder_layer( - hidden_states, - attention_mask=causal_mask, - position_ids=position_ids, - past_key_value=past_key_values, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - **kwargs, - ) - hidden_states = self.norm(hidden_states) - return BaseModelOutputWithPast( - last_hidden_state=hidden_states, - past_key_values=past_key_values if use_cache else None, - ) - - -@auto_docstring( - custom_intro=""" - The MyNewModel2 Model transformer with a sequence classification head on top (linear layer). - - [`MyNewModel2ForSequenceClassification`] uses the last token in order to do the classification, as other causal models - (e.g. GPT-2) do. - - Since it does classification on the last token, it requires to know the position of the last token. If a - `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If - no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the - padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in - each row of the batch). - """ -) -class MyNewModel2ForSequenceClassification(MyNewModel2PreTrainedModel): - def __init__(self, config): - super().__init__(config) - self.num_labels = config.num_labels - self.model = MyNewModel2Model(config) - self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Cache] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - **kwargs: Unpack[TransformersKwargs], - ) -> SequenceClassifierOutputWithPast: - r""" - labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): - Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., - config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If - `config.num_labels > 1` a classification loss is computed (Cross-Entropy). - """ - - transformer_outputs: BaseModelOutputWithPast = self.model( - input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - hidden_states = transformer_outputs.last_hidden_state - logits = self.score(hidden_states) - - if input_ids is not None: - batch_size = input_ids.shape[0] - else: - batch_size = inputs_embeds.shape[0] - - if self.config.pad_token_id is None and batch_size != 1: - raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") - if self.config.pad_token_id is None: - last_non_pad_token = -1 - elif input_ids is not None: - # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id - non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32) - token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32) - last_non_pad_token = (token_indices * non_pad_mask).argmax(-1) - else: - last_non_pad_token = -1 - logger.warning_once( - f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " - "unexpected if using padding tokens in conjunction with `inputs_embeds.`" - ) - - pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token] - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) - - return SequenceClassifierOutputWithPast( - loss=loss, - logits=pooled_logits, - past_key_values=transformer_outputs.past_key_values, - hidden_states=transformer_outputs.hidden_states, - attentions=transformer_outputs.attentions, - ) +class MyNewModel2ForSequenceClassification(GenericForSequenceClassification, MyNewModel2PreTrainedModel): + pass diff --git a/examples/modular-transformers/modeling_new_task_model.py b/examples/modular-transformers/modeling_new_task_model.py index 2a3df8e9c1..eb35c2ade5 100644 --- a/examples/modular-transformers/modeling_new_task_model.py +++ b/examples/modular-transformers/modeling_new_task_model.py @@ -16,7 +16,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils import ModelOutput, auto_docstring, can_return_tuple from ..auto import AutoModel from .configuration_new_task_model import NewTaskModelConfig @@ -29,7 +29,7 @@ from .configuration_new_task_model import NewTaskModelConfig ) class NewTaskModelModelOutputWithPast(BaseModelOutputWithPast): r""" - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -55,7 +55,7 @@ class NewTaskModelCausalLMOutputWithPast(ModelOutput): Language modeling loss (for next-token prediction). logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.text_config.vocab_size)`): Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax). - past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`) @@ -87,13 +87,12 @@ class NewTaskModelMultiModalProjector(nn.Module): @auto_docstring class NewTaskModelPreTrainedModel(PreTrainedModel): - config_class = NewTaskModelConfig + config: NewTaskModelConfig base_model_prefix = "" supports_gradient_checkpointing = True _no_split_modules = ["NewTaskModelMultiModalProjector"] _skip_keys_device_placement = "past_key_values" - _supports_cache_class = True - _supports_quantized_cache = True + _can_compile_fullgraph = True _supports_flash_attn = True _supports_sdpa = True @@ -229,6 +228,30 @@ class NewTaskModelModel(NewTaskModelPreTrainedModel): image_features = image_features / (self.config.text_config.hidden_size**0.5) return image_features + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, image_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholdr mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_image_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_image_mask = special_image_mask.all(-1) + else: + special_image_mask = input_ids == self.config.image_token_id + + n_image_tokens = special_image_mask.sum() + special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + n_image_features = image_features.shape[0] * image_features.shape[1] + if inputs_embeds[special_image_mask].numel() != image_features.numel(): + raise ValueError( + f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" + ) + return special_image_mask + @can_return_tuple @auto_docstring def forward( @@ -310,25 +333,10 @@ class NewTaskModelModel(NewTaskModelPreTrainedModel): # Merge text and images if pixel_values is not None: image_features = self.get_image_features(pixel_values) - - if input_ids is None: - special_image_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_image_mask = special_image_mask.all(-1) - else: - special_image_mask = input_ids == self.config.image_token_id - - special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - - if not is_torchdynamo_compiling() and inputs_embeds[special_image_mask].numel() != image_features.numel(): - image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] - raise ValueError( - f"Number of images does not match number of special image tokens in the input text. " - f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " - "tokens from image embeddings." - ) image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) + special_image_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, image_features=image_features + ) inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) causal_mask = self._update_causal_mask( diff --git a/examples/modular-transformers/modeling_roberta.py b/examples/modular-transformers/modeling_roberta.py index 320b8eee15..68ce50d452 100644 --- a/examples/modular-transformers/modeling_roberta.py +++ b/examples/modular-transformers/modeling_roberta.py @@ -13,12 +13,14 @@ import torch.nn as nn from packaging import version from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPoolingAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta import RobertaConfig @@ -93,7 +95,7 @@ class RobertaEmbeddings(nn.Module): class RobertaSelfAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): raise ValueError( @@ -118,66 +120,68 @@ class RobertaSelfAttention(nn.Module): self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) self.is_decoder = config.is_decoder + self.layer_idx = layer_idx - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - mixed_query_layer = self.query(hidden_states) + batch_size, seq_length, _ = hidden_states.shape + query_layer = self.query(hidden_states) + query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) - # If this is instantiated as a cross-attention module, the keys - # and values come from an encoder; the attention mask needs to be - # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_layer from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values - if is_cross_attention and past_key_value is not None: + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] - attention_mask = encoder_attention_mask - elif is_cross_attention: - key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) - value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) - attention_mask = encoder_attention_mask - elif past_key_value is not None: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) + key_layer = self.key(current_states) + key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose( + 1, 2 + ) + value_layer = self.value(current_states) + value_layer = value_layer.view( + batch_size, -1, self.num_attention_heads, self.attention_head_size + ).transpose(1, 2) - query_layer = self.transpose_for_scores(mixed_query_layer) - - use_cache = past_key_value is not None - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if use_cache: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -219,29 +223,26 @@ class RobertaSelfAttention(nn.Module): new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) context_layer = context_layer.view(new_context_layer_shape) - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return context_layer, attention_probs class RobertaSdpaSelfAttention(RobertaSelfAttention): - def __init__(self, config, position_embedding_type=None): - super().__init__(config, position_embedding_type=position_embedding_type) + def __init__(self, config, position_embedding_type=None, layer_idx=None): + super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx) self.dropout_prob = config.attention_probs_dropout_prob self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from RobertaSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None: # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented. @@ -257,41 +258,56 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): attention_mask, head_mask, encoder_hidden_states, - encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, + cache_position, ) bsz, tgt_len, _ = hidden_states.size() - query_layer = self.transpose_for_scores(self.query(hidden_states)) + query_layer = ( + self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2) + ) - # If this is instantiated as a cross-attention module, the keys and values come from an encoder; the attention - # mask needs to be such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None + current_states = encoder_hidden_states if is_cross_attention else hidden_states + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_values.cross_attention_cache + else: + curr_past_key_value = past_key_values.self_attention_cache + else: + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - - # Check `seq_length` of `past_key_value` == `len(current_states)` to support prefix tuning - if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: - key_layer, value_layer = past_key_value + if is_cross_attention and past_key_values is not None and is_updated: + # reuse k,v, cross_attentions + key_layer = curr_past_key_value.layers[self.layer_idx].keys + value_layer = curr_past_key_value.layers[self.layer_idx].values else: - key_layer = self.transpose_for_scores(self.key(current_states)) - value_layer = self.transpose_for_scores(self.value(current_states)) - if past_key_value is not None and not is_cross_attention: - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = ( + self.key(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) + value_layer = ( + self.value(current_states) + .view(bsz, -1, self.num_attention_heads, self.attention_head_size) + .transpose(1, 2) + ) - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + if past_key_values is not None: + # save all key/value_layer to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_layer, value_layer = curr_past_key_value.update( + key_layer, value_layer, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -305,9 +321,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create # a causal mask in case tgt_len == 1. - is_causal = ( - True if self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 else False - ) + is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1 attn_output = torch.nn.functional.scaled_dot_product_attention( query_layer, @@ -321,10 +335,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): attn_output = attn_output.transpose(1, 2) attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size) - outputs = (attn_output,) - if self.is_decoder: - outputs = outputs + (past_key_value,) - return outputs + return attn_output, None class RobertaSelfOutput(nn.Module): @@ -348,10 +359,12 @@ ROBERTA_SELF_ATTENTION_CLASSES = { class RobertaAttention(nn.Module): - def __init__(self, config, position_embedding_type=None): + def __init__(self, config, position_embedding_type=None, layer_idx=None): super().__init__() self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation]( - config, position_embedding_type=position_embedding_type + config, + position_embedding_type=position_embedding_type, + layer_idx=layer_idx, ) self.output = RobertaSelfOutput(config) self.pruned_heads = set() @@ -374,24 +387,25 @@ class RobertaAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: self_outputs = self.self( hidden_states, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - past_key_value, - output_attentions, + attention_mask=attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = self.output(self_outputs[0], hidden_states) outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them @@ -428,20 +442,21 @@ class RobertaOutput(nn.Module): class RobertaLayer(GradientCheckpointingLayer): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.chunk_size_feed_forward = config.chunk_size_feed_forward self.seq_len_dim = 1 - self.attention = RobertaAttention(config) + self.attention = RobertaAttention(config, layer_idx=layer_idx) self.is_decoder = config.is_decoder self.add_cross_attention = config.add_cross_attention if self.add_cross_attention: if not self.is_decoder: raise ValueError(f"{self} should be used as a decoder model if cross attention is added") - self.crossattention = RobertaAttention(config, position_embedding_type="absolute") + self.crossattention = RobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx) self.intermediate = RobertaIntermediate(config) self.output = RobertaOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -449,28 +464,21 @@ class RobertaLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, + cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None self_attention_outputs = self.attention( hidden_states, - attention_mask, - head_mask, + attention_mask=attention_mask, + head_mask=head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_values=past_key_values, + cache_position=cache_position, ) attention_output = self_attention_outputs[0] + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - # if decoder, the last output is tuple of self-attn cache - if self.is_decoder: - outputs = self_attention_outputs[1:-1] - present_key_value = self_attention_outputs[-1] - else: - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - cross_attn_present_key_value = None if self.is_decoder and encoder_hidden_states is not None: if not hasattr(self, "crossattention"): raise ValueError( @@ -478,33 +486,23 @@ class RobertaLayer(GradientCheckpointingLayer): " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None cross_attention_outputs = self.crossattention( attention_output, - attention_mask, - head_mask, - encoder_hidden_states, - encoder_attention_mask, - cross_attn_past_key_value, - output_attentions, + attention_mask=encoder_attention_mask, + head_mask=head_mask, + encoder_hidden_states=encoder_hidden_states, + past_key_values=past_key_values, + output_attentions=output_attentions, + cache_position=cache_position, ) attention_output = cross_attention_outputs[0] - outputs = outputs + cross_attention_outputs[1:-1] # add cross attentions if we output attention weights - - # add cross-attn cache to positions 3,4 of present_key_value tuple - cross_attn_present_key_value = cross_attention_outputs[-1] - present_key_value = present_key_value + cross_attn_present_key_value + outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights layer_output = apply_chunking_to_forward( self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs - # if decoder, return the attn key/values as the last output - if self.is_decoder: - outputs = outputs + (present_key_value,) - return outputs def feed_forward_chunk(self, attention_output): @@ -514,10 +512,10 @@ class RobertaLayer(GradientCheckpointingLayer): class RobertaEncoder(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx=None): super().__init__() self.config = config - self.layer = nn.ModuleList([RobertaLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.gradient_checkpointing = False def forward( @@ -532,6 +530,7 @@ class RobertaEncoder(nn.Module): output_attentions: Optional[bool] = False, output_hidden_states: Optional[bool] = False, return_dict: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: all_hidden_states = () if output_hidden_states else None all_self_attentions = () if output_attentions else None @@ -544,13 +543,21 @@ class RobertaEncoder(nn.Module): ) use_cache = False - next_decoder_cache = () if use_cache else None + return_legacy_cache = False + if use_cache and self.config.is_decoder and not isinstance(past_key_values, Cache): + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + return_legacy_cache = True + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + for i, layer_module in enumerate(self.layer): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -558,13 +565,12 @@ class RobertaEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = layer_outputs[0] - if use_cache: - next_decoder_cache += (layer_outputs[-1],) if output_attentions: all_self_attentions = all_self_attentions + (layer_outputs[1],) if self.config.add_cross_attention: @@ -573,12 +579,15 @@ class RobertaEncoder(nn.Module): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ hidden_states, - next_decoder_cache, + past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions, @@ -587,7 +596,7 @@ class RobertaEncoder(nn.Module): ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=next_decoder_cache, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, @@ -724,7 +733,7 @@ def load_tf_weights_in_roberta(model, config, tf_checkpoint_path): @auto_docstring class RobertaPreTrainedModel(PreTrainedModel): - config_class = RobertaConfig + config: RobertaConfig load_tf_weights = load_tf_weights_in_roberta base_model_prefix = "roberta" supports_gradient_checkpointing = True @@ -813,6 +822,7 @@ class RobertaModel(RobertaPreTrainedModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -838,8 +848,13 @@ class RobertaModel(RobertaPreTrainedModel): batch_size, seq_length = input_shape device = input_ids.device if input_ids is not None else inputs_embeds.device - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = ( + past_key_values[0][0].shape[-2] + if not isinstance(past_key_values, Cache) + else past_key_values.get_seq_length() + ) if token_type_ids is None: if hasattr(self.embeddings, "token_type_ids"): @@ -924,6 +939,7 @@ class RobertaModel(RobertaPreTrainedModel): output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = encoder_outputs[0] pooled_output = self.pooler(sequence_output) if self.pooler is not None else None diff --git a/examples/modular-transformers/modeling_super.py b/examples/modular-transformers/modeling_super.py index ee90750cac..6927dab86d 100644 --- a/examples/modular-transformers/modeling_super.py +++ b/examples/modular-transformers/modeling_super.py @@ -19,6 +19,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_super import SuperConfig @@ -192,12 +193,13 @@ class SuperAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -211,10 +213,10 @@ class SuperAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -247,12 +249,13 @@ class SuperDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = SuperRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -265,7 +268,7 @@ class SuperDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -283,7 +286,7 @@ class SuperDecoderLayer(GradientCheckpointingLayer): @auto_docstring class SuperPreTrainedModel(PreTrainedModel): - config_class = SuperConfig + config: SuperConfig base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["SuperDecoderLayer"] @@ -291,8 +294,7 @@ class SuperPreTrainedModel(PreTrainedModel): _supports_flash_attn = True _supports_sdpa = True _supports_flex_attn = True - _supports_cache_class = True - _supports_quantized_cache = True + _can_compile_fullgraph = True _supports_attention_backend = True _can_record_outputs = { @@ -300,19 +302,6 @@ class SuperPreTrainedModel(PreTrainedModel): "attentions": SuperAttention, } - def _init_weights(self, module): - std = self.config.initializer_range - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - elif isinstance(module, SuperRMSNorm): - module.weight.data.fill_(1.0) - @auto_docstring class SuperModel(SuperPreTrainedModel): diff --git a/examples/modular-transformers/modeling_switch_function.py b/examples/modular-transformers/modeling_switch_function.py index 6b443d3411..6e8ffed806 100644 --- a/examples/modular-transformers/modeling_switch_function.py +++ b/examples/modular-transformers/modeling_switch_function.py @@ -14,6 +14,7 @@ from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs +from ...utils.deprecation import deprecate_kwarg from .configuration_switch_function import SwitchFunctionConfig @@ -116,12 +117,13 @@ class SwitchFunctionAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -135,10 +137,10 @@ class SwitchFunctionAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/examples/modular-transformers/modeling_test_detr.py b/examples/modular-transformers/modeling_test_detr.py index e1286c3256..ac2d45a863 100644 --- a/examples/modular-transformers/modeling_test_detr.py +++ b/examples/modular-transformers/modeling_test_detr.py @@ -803,7 +803,7 @@ class TestDetrDecoderLayer(GradientCheckpointingLayer): @auto_docstring class TestDetrPreTrainedModel(PreTrainedModel): - config_class = TestDetrConfig + config: TestDetrConfig base_model_prefix = "model" main_input_name = "pixel_values" supports_gradient_checkpointing = True diff --git a/examples/modular-transformers/modular_dummy_bert.py b/examples/modular-transformers/modular_dummy_bert.py index 34d2cd1b33..fb7440228d 100644 --- a/examples/modular-transformers/modular_dummy_bert.py +++ b/examples/modular-transformers/modular_dummy_bert.py @@ -23,5 +23,6 @@ class DummyBertModel(BertModel): output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]: return super().forward(input_ids) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index b502400dbb..bb5aac99b3 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1112,7 +1112,7 @@ class Cache: def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self.layers): @@ -1124,7 +1124,7 @@ class Cache: def __iter__(self): """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over keys and values """ for layer_idx in range(len(self)): @@ -1132,7 +1132,7 @@ class Cache: def __len__(self): """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds to the number of layers in the model. """ # Best effort BC support for old-style caches like Mambas, Falcon, HybridChunked that rely on __len__ @@ -1742,7 +1742,7 @@ class EncoderDecoderCache(Cache): def __iter__(self): """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over keys and values """ for layer_idx in range(len(self)): @@ -1755,7 +1755,7 @@ class EncoderDecoderCache(Cache): def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): @@ -1770,7 +1770,7 @@ class EncoderDecoderCache(Cache): def __len__(self): """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds to the number of layers in the model. """ return len(self.self_attention_cache) diff --git a/src/transformers/modeling_layers.py b/src/transformers/modeling_layers.py index 16e7a7e05b..7e29ab1886 100644 --- a/src/transformers/modeling_layers.py +++ b/src/transformers/modeling_layers.py @@ -70,6 +70,7 @@ class GradientCheckpointingLayer(nn.Module): do_warn = True # different names for the same thing in different layers + # TODO cyril: this one without `S` can be removed after deprection cycle if "past_key_value" in kwargs and kwargs["past_key_value"] is not None: kwargs["past_key_value"] = None message += " `past_key_value=None`," diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py index bb7d419a14..661d3a72dc 100644 --- a/src/transformers/models/arcee/modeling_arcee.py +++ b/src/transformers/models/arcee/modeling_arcee.py @@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_arcee import ArceeConfig @@ -215,12 +216,13 @@ class ArceeAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -234,10 +236,10 @@ class ArceeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -270,12 +272,13 @@ class ArceeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ArceeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -288,7 +291,7 @@ class ArceeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -389,7 +392,7 @@ class ArceeModel(ArceePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py index 9a87c4cc50..bfb2e6b194 100644 --- a/src/transformers/models/aria/modeling_aria.py +++ b/src/transformers/models/aria/modeling_aria.py @@ -33,6 +33,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from ...utils.import_utils import is_torch_available from ..auto import AutoModel @@ -522,12 +523,13 @@ class AriaTextAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -541,10 +543,10 @@ class AriaTextAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -588,12 +590,13 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = AriaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -606,7 +609,7 @@ class AriaTextDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -769,7 +772,7 @@ class AriaTextModel(AriaTextPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 3a81f3e284..bc0e5faf59 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -36,6 +36,7 @@ from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPrediction from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_autoformer import AutoformerConfig @@ -451,11 +452,12 @@ class AutoformerAttention(nn.Module): self.autocorrelation_factor = autocorrelation_factor + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -471,19 +473,19 @@ class AutoformerAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -493,7 +495,7 @@ class AutoformerAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -501,7 +503,7 @@ class AutoformerAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -751,6 +753,7 @@ class AutoformerDecoderLayer(GradientCheckpointingLayer): bias=False, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -759,7 +762,7 @@ class AutoformerDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -777,7 +780,7 @@ class AutoformerDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -790,7 +793,7 @@ class AutoformerDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -812,7 +815,7 @@ class AutoformerDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1205,7 +1208,7 @@ class AutoformerDecoder(AutoformerPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/bamba/modeling_bamba.py b/src/transformers/models/bamba/modeling_bamba.py index e31c0a53ed..eaae2133b6 100644 --- a/src/transformers/models/bamba/modeling_bamba.py +++ b/src/transformers/models/bamba/modeling_bamba.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig @@ -336,12 +337,13 @@ class BambaAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -355,10 +357,10 @@ class BambaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1003,12 +1005,13 @@ class BambaDecoderLayer(GradientCheckpointingLayer): else: raise ValueError("Invalid layer_type") + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1020,7 +1023,7 @@ class BambaDecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1045,7 +1048,7 @@ class BambaDecoderLayer(GradientCheckpointingLayer): if self.layer_type == "mamba": hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, **kwargs, @@ -1056,7 +1059,7 @@ class BambaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1189,7 +1192,7 @@ class BambaModel(BambaPreTrainedModel): hidden_states, attention_mask=layer_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/bamba/modular_bamba.py b/src/transformers/models/bamba/modular_bamba.py index be58fd3abd..95c5ce8e36 100644 --- a/src/transformers/models/bamba/modular_bamba.py +++ b/src/transformers/models/bamba/modular_bamba.py @@ -52,6 +52,7 @@ from ...utils import ( can_return_tuple, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_bamba import BambaConfig @@ -725,12 +726,13 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer): else: raise ValueError("Invalid layer_type") + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -742,7 +744,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -767,7 +769,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer): if self.layer_type == "mamba": hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, **kwargs, @@ -778,7 +780,7 @@ class BambaDecoderLayer(JambaAttentionDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -911,7 +913,7 @@ class BambaModel(BambaPreTrainedModel): hidden_states, attention_mask=layer_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 236a2f6471..de20331e82 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -50,6 +50,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_bart import BartConfig @@ -186,11 +187,12 @@ class BartAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -215,19 +217,19 @@ class BartAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -237,7 +239,7 @@ class BartAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -245,7 +247,7 @@ class BartAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -373,6 +375,7 @@ class BartDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -381,7 +384,7 @@ class BartDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -399,7 +402,7 @@ class BartDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -412,7 +415,7 @@ class BartDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -432,7 +435,7 @@ class BartDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1119,7 +1122,7 @@ class BartDecoder(BartPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1284,7 +1287,7 @@ class BartModel(BartPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py index 90ed959176..04323c7ec4 100755 --- a/src/transformers/models/bert/modeling_bert.py +++ b/src/transformers/models/bert/modeling_bert.py @@ -46,6 +46,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bert import BertConfig @@ -217,13 +218,14 @@ class BertSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -234,19 +236,19 @@ class BertSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -260,7 +262,7 @@ class BertSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -268,14 +270,14 @@ class BertSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -327,13 +329,14 @@ class BertSdpaSelfAttention(BertSelfAttention): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from BertSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -351,7 +354,7 @@ class BertSdpaSelfAttention(BertSelfAttention): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -364,19 +367,19 @@ class BertSdpaSelfAttention(BertSelfAttention): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -392,7 +395,7 @@ class BertSdpaSelfAttention(BertSelfAttention): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -400,7 +403,7 @@ class BertSdpaSelfAttention(BertSelfAttention): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -480,13 +483,14 @@ class BertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -495,7 +499,7 @@ class BertAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -548,6 +552,7 @@ class BertLayer(GradientCheckpointingLayer): self.intermediate = BertIntermediate(config) self.output = BertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -555,7 +560,7 @@ class BertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -564,7 +569,7 @@ class BertLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -582,7 +587,7 @@ class BertLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -656,7 +661,7 @@ class BertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 10ba8baca0..d582a91439 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -29,6 +29,7 @@ from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, Causa from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bert_generation import BertGenerationConfig @@ -79,13 +80,14 @@ class BertGenerationSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -96,19 +98,19 @@ class BertGenerationSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -122,7 +124,7 @@ class BertGenerationSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -130,14 +132,14 @@ class BertGenerationSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -217,13 +219,14 @@ class BertGenerationAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -232,7 +235,7 @@ class BertGenerationAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -290,6 +293,7 @@ class BertGenerationLayer(GradientCheckpointingLayer): self.intermediate = BertGenerationIntermediate(config) self.output = BertGenerationOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -297,7 +301,7 @@ class BertGenerationLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -306,7 +310,7 @@ class BertGenerationLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -324,7 +328,7 @@ class BertGenerationLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -398,7 +402,7 @@ class BertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 8112f10200..97f6187a1f 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_big_bird import BigBirdConfig @@ -315,6 +316,7 @@ class BigBirdSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -322,7 +324,7 @@ class BigBirdSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -336,10 +338,10 @@ class BigBirdSelfAttention(nn.Module): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: + if is_cross_attention and past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = past_key_value.layers[self.layer_idx].keys - value_layer = past_key_value.layers[self.layer_idx].values + key_layer = past_key_values.layers[self.layer_idx].keys + value_layer = past_key_values.layers[self.layer_idx].values else: key_layer = ( self.key(current_states) @@ -352,9 +354,9 @@ class BigBirdSelfAttention(nn.Module): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation - key_layer, value_layer = past_key_value.update( + key_layer, value_layer = past_key_values.update( key_layer, value_layer, self.layer_idx, @@ -1338,6 +1340,7 @@ class BigBirdAttention(nn.Module): if not self.training: self.self.eval() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -1345,7 +1348,7 @@ class BigBirdAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, # block_sparse config band_mask=None, @@ -1369,7 +1372,7 @@ class BigBirdAttention(nn.Module): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1447,6 +1450,7 @@ class BigBirdLayer(GradientCheckpointingLayer): if self.add_cross_attention: self.crossattention.set_attention_type(value, layer_idx=layer_idx) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -1458,7 +1462,7 @@ class BigBirdLayer(GradientCheckpointingLayer): from_mask=None, to_mask=None, blocked_encoder_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -1469,7 +1473,7 @@ class BigBirdLayer(GradientCheckpointingLayer): head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, band_mask=band_mask, from_mask=from_mask, @@ -1493,7 +1497,7 @@ class BigBirdLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 40013e4f32..9af25afd79 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -44,6 +44,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_bigbird_pegasus import BigBirdPegasusConfig @@ -127,6 +128,7 @@ class BigBirdPegasusSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -134,7 +136,7 @@ class BigBirdPegasusSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -148,10 +150,10 @@ class BigBirdPegasusSelfAttention(nn.Module): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - if is_cross_attention and past_key_value is not None and past_key_value.get_seq_length(self.layer_idx) > 0: + if is_cross_attention and past_key_values is not None and past_key_values.get_seq_length(self.layer_idx) > 0: # reuse k,v, cross_attentions - key_layer = past_key_value.layers[self.layer_idx].keys - value_layer = past_key_value.layers[self.layer_idx].values + key_layer = past_key_values.layers[self.layer_idx].keys + value_layer = past_key_values.layers[self.layer_idx].values else: key_layer = ( self.key(current_states) @@ -164,9 +166,9 @@ class BigBirdPegasusSelfAttention(nn.Module): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation - key_layer, value_layer = past_key_value.update( + key_layer, value_layer = past_key_values.update( key_layer, value_layer, self.layer_idx, @@ -1244,11 +1246,12 @@ class BigBirdPegasusDecoderAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -1273,19 +1276,19 @@ class BigBirdPegasusDecoderAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -1295,7 +1298,7 @@ class BigBirdPegasusDecoderAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -1303,7 +1306,7 @@ class BigBirdPegasusDecoderAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1456,7 +1459,7 @@ class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -1474,7 +1477,7 @@ class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1488,7 +1491,7 @@ class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -1508,7 +1511,7 @@ class BigBirdPegasusDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -2275,7 +2278,7 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -2426,7 +2429,7 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d6ed401cd6..bba29c892f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_biogpt import BioGptConfig @@ -164,11 +165,12 @@ class BioGptAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -193,19 +195,19 @@ class BioGptAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -215,7 +217,7 @@ class BioGptAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -223,7 +225,7 @@ class BioGptAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -272,12 +274,13 @@ class BioGptDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, position_ids: Optional[torch.LongTensor] = None, @@ -291,7 +294,7 @@ class BioGptDecoderLayer(GradientCheckpointingLayer): `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -309,7 +312,7 @@ class BioGptDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -624,7 +627,7 @@ class BioGptModel(BioGptPreTrainedModel): hidden_states, attention_mask=causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_ids=position_ids, diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index db5ad5dbbc..7b29640cd8 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -42,6 +42,7 @@ from ...utils import ( is_torch_flex_attn_available, logger, ) +from ...utils.deprecation import deprecate_kwarg from ..bart.modeling_bart import ( BartAttention, BartDecoderLayer, @@ -97,12 +98,13 @@ class BioGptDecoderLayer(BartDecoderLayer): del self.encoder_attn del self.encoder_attn_layer_norm + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, position_ids: Optional[torch.LongTensor] = None, @@ -116,7 +118,7 @@ class BioGptDecoderLayer(BartDecoderLayer): `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -134,7 +136,7 @@ class BioGptDecoderLayer(BartDecoderLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -449,7 +451,7 @@ class BioGptModel(BioGptPreTrainedModel): hidden_states, attention_mask=causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_ids=position_ids, diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py index dee9edfc2e..6c4d8e21e2 100644 --- a/src/transformers/models/bitnet/modeling_bitnet.py +++ b/src/transformers/models/bitnet/modeling_bitnet.py @@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_bitnet import BitNetConfig @@ -176,12 +177,13 @@ class BitNetAttention(nn.Module): ) self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -195,10 +197,10 @@ class BitNetAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -233,12 +235,13 @@ class BitNetDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -251,7 +254,7 @@ class BitNetDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -388,7 +391,7 @@ class BitNetModel(BitNetPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/bitnet/modular_bitnet.py b/src/transformers/models/bitnet/modular_bitnet.py index e8f29db38d..92ad3a3214 100644 --- a/src/transformers/models/bitnet/modular_bitnet.py +++ b/src/transformers/models/bitnet/modular_bitnet.py @@ -23,6 +23,7 @@ from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import GemmaMLP from ..llama.modeling_llama import ( LlamaAttention, @@ -58,12 +59,13 @@ class BitNetAttention(LlamaAttention): super().__init__(config, layer_idx) self.attn_sub_norm = BitNetRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -77,10 +79,10 @@ class BitNetAttention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index b4ec543b3e..78dd0223bc 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -49,6 +49,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel from .configuration_blenderbot import BlenderbotConfig @@ -185,11 +186,12 @@ class BlenderbotAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -214,19 +216,19 @@ class BlenderbotAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -236,7 +238,7 @@ class BlenderbotAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -244,7 +246,7 @@ class BlenderbotAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -366,6 +368,7 @@ class BlenderbotDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -374,7 +377,7 @@ class BlenderbotDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -392,7 +395,7 @@ class BlenderbotDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -406,7 +409,7 @@ class BlenderbotDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -426,7 +429,7 @@ class BlenderbotDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1073,7 +1076,7 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1237,7 +1240,7 @@ class BlenderbotModel(BlenderbotPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index b248b2f0da..212c7cb135 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -47,6 +47,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_blenderbot_small import BlenderbotSmallConfig @@ -169,11 +170,12 @@ class BlenderbotSmallAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -198,19 +200,19 @@ class BlenderbotSmallAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -220,7 +222,7 @@ class BlenderbotSmallAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -228,7 +230,7 @@ class BlenderbotSmallAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -358,6 +360,7 @@ class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -366,7 +369,7 @@ class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -384,7 +387,7 @@ class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -397,7 +400,7 @@ class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -417,7 +420,7 @@ class BlenderbotSmallDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1061,7 +1064,7 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1209,7 +1212,7 @@ class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/blip/modeling_blip.py b/src/transformers/models/blip/modeling_blip.py index 267b0ffcb0..4a58d70eb0 100644 --- a/src/transformers/models/blip/modeling_blip.py +++ b/src/transformers/models/blip/modeling_blip.py @@ -438,7 +438,7 @@ class BlipPreTrainedModel(PreTrainedModel): base_model_prefix = "blip" supports_gradient_checkpointing = True _no_split_modules = ["BlipEncoderLayer", "BlipTextEmbeddings"] - _skip_keys_device_placement = ["past_key_value"] + _skip_keys_device_placement = ["past_key_values"] def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/blip/modeling_blip_text.py b/src/transformers/models/blip/modeling_blip_text.py index bdf26fe39f..a75582200b 100644 --- a/src/transformers/models/blip/modeling_blip_text.py +++ b/src/transformers/models/blip/modeling_blip_text.py @@ -38,6 +38,7 @@ from ...modeling_utils import ( prune_linear_layer, ) from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from .configuration_blip import BlipTextConfig @@ -138,6 +139,7 @@ class BlipTextSelfAttention(nn.Module): def get_attention_map(self): return self.attention_map + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -145,7 +147,7 @@ class BlipTextSelfAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -162,19 +164,19 @@ class BlipTextSelfAttention(nn.Module): is_cross_attention = encoder_hidden_states is not None attention_mask = encoder_attention_mask if is_cross_attention else attention_mask - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -190,7 +192,7 @@ class BlipTextSelfAttention(nn.Module): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -198,7 +200,7 @@ class BlipTextSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -285,13 +287,14 @@ class BlipTextAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -300,7 +303,7 @@ class BlipTextAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -355,6 +358,7 @@ class BlipTextLayer(GradientCheckpointingLayer): self.intermediate = BlipTextIntermediate(config) self.output = BlipTextOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -362,7 +366,7 @@ class BlipTextLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -371,7 +375,7 @@ class BlipTextLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -383,7 +387,7 @@ class BlipTextLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/bridgetower/modeling_bridgetower.py b/src/transformers/models/bridgetower/modeling_bridgetower.py index 650543adab..b547d160ab 100644 --- a/src/transformers/models/bridgetower/modeling_bridgetower.py +++ b/src/transformers/models/bridgetower/modeling_bridgetower.py @@ -37,6 +37,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel, apply_chunking_to_forward from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_bridgetower import BridgeTowerConfig, BridgeTowerTextConfig, BridgeTowerVisionConfig @@ -429,13 +430,14 @@ class BridgeTowerSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -446,19 +448,19 @@ class BridgeTowerSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -472,7 +474,7 @@ class BridgeTowerSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -480,14 +482,14 @@ class BridgeTowerSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -567,13 +569,14 @@ class BridgeTowerAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -582,7 +585,7 @@ class BridgeTowerAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -603,6 +606,7 @@ class BridgeTowerBertCrossLayer(nn.Module): self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -610,7 +614,7 @@ class BridgeTowerBertCrossLayer(nn.Module): attention_mask=None, head_mask=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -620,7 +624,7 @@ class BridgeTowerBertCrossLayer(nn.Module): attention_mask=attention_mask, head_mask=None, output_attentions=output_attentions, - past_key_value=None, + past_key_values=None, ) attention_output = self_attention_outputs[0] @@ -633,7 +637,7 @@ class BridgeTowerBertCrossLayer(nn.Module): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -669,6 +673,7 @@ class BridgeTowerTextLayer(GradientCheckpointingLayer): self.intermediate = BridgeTowerIntermediate(config) self.output = BridgeTowerOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -676,7 +681,7 @@ class BridgeTowerTextLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -686,7 +691,7 @@ class BridgeTowerTextLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -709,7 +714,7 @@ class BridgeTowerTextLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -784,7 +789,7 @@ class BridgeTowerTextEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/camembert/modeling_camembert.py b/src/transformers/models/camembert/modeling_camembert.py index bb983794ab..746112afef 100644 --- a/src/transformers/models/camembert/modeling_camembert.py +++ b/src/transformers/models/camembert/modeling_camembert.py @@ -42,6 +42,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_camembert import CamembertConfig @@ -167,13 +168,14 @@ class CamembertSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -184,19 +186,19 @@ class CamembertSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -210,7 +212,7 @@ class CamembertSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -218,14 +220,14 @@ class CamembertSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -278,13 +280,14 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from CamembertSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -302,7 +305,7 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -315,19 +318,19 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -343,7 +346,7 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -351,7 +354,7 @@ class CamembertSdpaSelfAttention(CamembertSelfAttention): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -433,13 +436,14 @@ class CamembertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -448,7 +452,7 @@ class CamembertAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -504,6 +508,7 @@ class CamembertLayer(GradientCheckpointingLayer): self.intermediate = CamembertIntermediate(config) self.output = CamembertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -511,7 +516,7 @@ class CamembertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -520,7 +525,7 @@ class CamembertLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -538,7 +543,7 @@ class CamembertLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -613,7 +618,7 @@ class CamembertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index 91bbc7e87f..70818f824c 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -37,6 +37,7 @@ from ...utils import ( can_return_tuple, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_chameleon import ChameleonConfig, ChameleonVQVAEConfig @@ -311,12 +312,13 @@ class ChameleonAttention(nn.Module): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -341,10 +343,10 @@ class ChameleonAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -379,12 +381,13 @@ class ChameleonDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -402,7 +405,7 @@ class ChameleonDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): @@ -418,7 +421,7 @@ class ChameleonDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -451,12 +454,13 @@ class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -471,7 +475,7 @@ class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): query_sequence_length, key_sequence_length)` if default attention is used. position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): Indices of positions of each input sequence tokens in the position embeddings - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -489,7 +493,7 @@ class ChameleonSwinDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -992,7 +996,7 @@ class ChameleonModel(ChameleonPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index b0170e6402..314c18c4d0 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -42,6 +42,7 @@ from ...utils import ( auto_docstring, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_clvp import ( ClvpConfig, ClvpDecoderConfig, @@ -298,13 +299,14 @@ class ClvpSelfAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.FloatTensor, rotary_pos_emb: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, head_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, @@ -322,8 +324,8 @@ class ClvpSelfAttention(nn.Module): key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - if past_key_value is not None: - key_states, value_states = past_key_value.update( + if past_key_values is not None: + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -607,10 +609,11 @@ class ClvpDecoderLayer(nn.Module): self.mlp = ClvpDecoderMLP(inner_dim, config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -622,7 +625,7 @@ class ClvpDecoderLayer(nn.Module): hidden_states = self.input_layernorm(hidden_states) attn_outputs = self.attn( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask, @@ -1128,7 +1131,7 @@ class ClvpDecoder(ClvpPreTrainedModel): else: outputs = block( hidden_states, - past_key_value=past_key_values, + past_key_values=past_key_values, attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], @@ -1213,7 +1216,7 @@ class ClvpModel(ClvpPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index a07d9b8d1b..79cb07e0b8 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_cohere import CohereConfig @@ -227,12 +228,13 @@ class CohereAttention(nn.Module): hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -254,10 +256,10 @@ class CohereAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -287,12 +289,13 @@ class CohereDecoderLayer(GradientCheckpointingLayer): self.mlp = CohereMLP(config) self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -304,7 +307,7 @@ class CohereDecoderLayer(GradientCheckpointingLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -324,7 +327,7 @@ class CohereDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -421,7 +424,7 @@ class CohereModel(CoherePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/cohere/modular_cohere.py b/src/transformers/models/cohere/modular_cohere.py index 9cc22c2732..4f05fedc98 100644 --- a/src/transformers/models/cohere/modular_cohere.py +++ b/src/transformers/models/cohere/modular_cohere.py @@ -36,6 +36,7 @@ from ...modeling_rope_utils import dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaAttention, LlamaForCausalLM, @@ -145,12 +146,13 @@ class CohereAttention(LlamaAttention): hidden_size=(config.num_key_value_heads, self.head_dim), eps=config.layer_norm_eps ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -172,10 +174,10 @@ class CohereAttention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -205,12 +207,13 @@ class CohereDecoderLayer(GradientCheckpointingLayer): self.mlp = CohereMLP(config) self.input_layernorm = CohereLayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -222,7 +225,7 @@ class CohereDecoderLayer(GradientCheckpointingLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -242,7 +245,7 @@ class CohereDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py index 8b31b2f349..8ab3952088 100644 --- a/src/transformers/models/cohere2/modeling_cohere2.py +++ b/src/transformers/models/cohere2/modeling_cohere2.py @@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_cohere2 import Cohere2Config @@ -195,12 +196,13 @@ class Cohere2Attention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -215,9 +217,9 @@ class Cohere2Attention(nn.Module): if self.sliding_window is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -265,12 +267,13 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Cohere2LayerNorm(hidden_size=(config.hidden_size), eps=config.layer_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -281,7 +284,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -300,7 +303,7 @@ class Cohere2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -400,7 +403,7 @@ class Cohere2Model(Cohere2PreTrainedModel): hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/cohere2/modular_cohere2.py b/src/transformers/models/cohere2/modular_cohere2.py index f0fa8f12ac..f232862421 100644 --- a/src/transformers/models/cohere2/modular_cohere2.py +++ b/src/transformers/models/cohere2/modular_cohere2.py @@ -28,6 +28,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..cohere.modeling_cohere import ( CohereAttention, CohereDecoderLayer, @@ -297,12 +298,13 @@ class Cohere2Attention(CohereAttention, nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -317,9 +319,9 @@ class Cohere2Attention(CohereAttention, nn.Module): if self.sliding_window is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -347,12 +349,13 @@ class Cohere2DecoderLayer(CohereDecoderLayer): super().__init__(config, layer_idx) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -363,7 +366,7 @@ class Cohere2DecoderLayer(CohereDecoderLayer): hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -434,7 +437,7 @@ class Cohere2Model(Gemma2Model): hidden_states, position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/csm/modeling_csm.py b/src/transformers/models/csm/modeling_csm.py index cde7284037..8bcc3b028c 100644 --- a/src/transformers/models/csm/modeling_csm.py +++ b/src/transformers/models/csm/modeling_csm.py @@ -38,6 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_csm import CsmConfig, CsmDepthDecoderConfig from .generation_csm import CsmGenerationMixin @@ -267,12 +268,13 @@ class CsmAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -286,10 +288,10 @@ class CsmAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -322,12 +324,13 @@ class CsmDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = CsmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -340,7 +343,7 @@ class CsmDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -481,7 +484,7 @@ class CsmDepthDecoderModel(CsmPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -727,7 +730,7 @@ class CsmBackboneModel(CsmPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/csm/modular_csm.py b/src/transformers/models/csm/modular_csm.py index ad11589283..bccd3b4d3a 100644 --- a/src/transformers/models/csm/modular_csm.py +++ b/src/transformers/models/csm/modular_csm.py @@ -231,7 +231,7 @@ class CsmDepthDecoderModel(LlamaModel, CsmPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/data2vec/modeling_data2vec_text.py b/src/transformers/models/data2vec/modeling_data2vec_text.py index 9f472b0250..e6754770d0 100644 --- a/src/transformers/models/data2vec/modeling_data2vec_text.py +++ b/src/transformers/models/data2vec/modeling_data2vec_text.py @@ -39,6 +39,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_data2vec_text import Data2VecTextConfig @@ -167,13 +168,14 @@ class Data2VecTextSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -184,19 +186,19 @@ class Data2VecTextSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -210,7 +212,7 @@ class Data2VecTextSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -218,14 +220,14 @@ class Data2VecTextSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -320,13 +322,14 @@ class Data2VecTextAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -335,7 +338,7 @@ class Data2VecTextAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -393,6 +396,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer): self.intermediate = Data2VecTextIntermediate(config) self.output = Data2VecTextOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -400,7 +404,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -409,7 +413,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -427,7 +431,7 @@ class Data2VecTextLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -502,7 +506,7 @@ class Data2VecTextEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 2440abf3b6..4bd02e8b8d 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -30,6 +30,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_dbrx import DbrxConfig @@ -239,12 +240,13 @@ class DbrxAttention(nn.Module): base=self.rope_theta, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -273,10 +275,10 @@ class DbrxAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -324,18 +326,19 @@ class DbrxFlashAttention2(DbrxAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Any, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -368,10 +371,10 @@ class DbrxFlashAttention2(DbrxAttention): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout # [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -440,12 +443,13 @@ class DbrxSdpaAttention(DbrxAttention): SDPA API. """ + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -460,7 +464,7 @@ class DbrxSdpaAttention(DbrxAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -488,10 +492,10 @@ class DbrxSdpaAttention(DbrxAttention): cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, None) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.block_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.block_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -547,12 +551,13 @@ class DbrxNormAttentionNorm(nn.Module): ) self.norm_2 = nn.LayerNorm(config.d_model, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_ids: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -565,7 +570,7 @@ class DbrxNormAttentionNorm(nn.Module): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -737,12 +742,13 @@ class DbrxBlock(GradientCheckpointingLayer): ) self.ffn = DbrxFFN(config=config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -765,7 +771,7 @@ class DbrxBlock(GradientCheckpointingLayer): attention_mask (`torch.Tensor`, *optional*): attention mask of size (batch_size, sequence_length) if flash attention is used or (batch_size, 1, query_sequence_length, key_sequence_length) if default attention is used. - past_key_value (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.Tensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. output_router_logits (`bool`, *optional*): Whether or not to return the router logits. @@ -779,7 +785,7 @@ class DbrxBlock(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -940,7 +946,7 @@ class DbrxModel(DbrxPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index 563615d217..6b3419beb6 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -34,6 +34,7 @@ from ...utils import ( auto_docstring, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_decision_transformer import DecisionTransformerConfig @@ -255,10 +256,11 @@ class DecisionTransformerGPT2Attention(nn.Module): return attn_output, attn_weights + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -268,16 +270,16 @@ class DecisionTransformerGPT2Attention(nn.Module): **kwargs, ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]: is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values if is_cross_attention: if not hasattr(self, "q_attn"): @@ -289,7 +291,7 @@ class DecisionTransformerGPT2Attention(nn.Module): attention_mask = encoder_attention_mask # Try to get key/value states from cache if possible - if past_key_value is not None and is_updated: + if past_key_values is not None and is_updated: key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: @@ -306,8 +308,8 @@ class DecisionTransformerGPT2Attention(nn.Module): shape_q = (*query_states.shape[:-1], -1, self.head_dim) query_states = query_states.view(shape_q).transpose(1, 2) - if (past_key_value is not None and not is_cross_attention) or ( - past_key_value is not None and is_cross_attention and not is_updated + if (past_key_values is not None and not is_cross_attention) or ( + past_key_values is not None and is_cross_attention and not is_updated ): # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None @@ -316,7 +318,7 @@ class DecisionTransformerGPT2Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention @@ -387,10 +389,11 @@ class DecisionTransformerGPT2Block(GradientCheckpointingLayer): self.mlp = DecisionTransformerGPT2MLP(inner_dim, config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -404,7 +407,7 @@ class DecisionTransformerGPT2Block(GradientCheckpointingLayer): hidden_states = self.ln_1(hidden_states) attn_output, self_attn_weights = self.attn( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, attention_mask=attention_mask, head_mask=head_mask, @@ -426,7 +429,7 @@ class DecisionTransformerGPT2Block(GradientCheckpointingLayer): hidden_states = self.ln_cross_attn(hidden_states) cross_attn_output, cross_attn_weights = self.crossattention( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index bd41e7ba75..5da9d687fb 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -37,6 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_deepseek_v2 import DeepseekV2Config @@ -334,11 +335,12 @@ class DeepseekV2Attention(nn.Module): self.scaling = self.qk_head_dim ** (-0.5) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.Tensor] = None, @@ -371,10 +373,10 @@ class DeepseekV2Attention(nn.Module): query_states = torch.cat((q_nope, q_pe), dim=-1) key_states = torch.cat((k_nope, k_pe), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) @@ -413,12 +415,13 @@ class DeepseekV2DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -431,7 +434,7 @@ class DeepseekV2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -537,7 +540,7 @@ class DeepseekV2Model(DeepseekV2PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py index dddb74d1bf..dad427debc 100644 --- a/src/transformers/models/deepseek_v2/modular_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modular_deepseek_v2.py @@ -25,6 +25,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import ( logging, ) +from ...utils.deprecation import deprecate_kwarg from ..llama.configuration_llama import LlamaConfig from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -422,11 +423,12 @@ class DeepseekV2Attention(nn.Module): self.scaling = self.qk_head_dim ** (-0.5) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.Tensor] = None, @@ -459,10 +461,10 @@ class DeepseekV2Attention(nn.Module): query_states = torch.cat((q_nope, q_pe), dim=-1) key_states = torch.cat((k_nope, k_pe), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index aa15c5ee87..5d87980080 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -23,6 +23,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_deepseek_v3 import DeepseekV3Config @@ -371,12 +372,13 @@ class DeepseekV3Attention(nn.Module): mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.scaling = self.scaling * mscale * mscale + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -409,10 +411,10 @@ class DeepseekV3Attention(nn.Module): query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) @@ -455,12 +457,13 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DeepseekV3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -473,7 +476,7 @@ class DeepseekV3DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -581,7 +584,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py index ebcd3aed39..791c4a1958 100644 --- a/src/transformers/models/deepseek_v3/modular_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modular_deepseek_v3.py @@ -12,6 +12,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaDecoderLayer, LlamaForCausalLM, @@ -252,12 +253,13 @@ class DeepseekV3Attention(nn.Module): mscale = yarn_get_mscale(scaling_factor, mscale_all_dim) self.scaling = self.scaling * mscale * mscale + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -290,10 +292,10 @@ class DeepseekV3Attention(nn.Module): query_states = torch.cat((q_pass, q_rot), dim=-1) key_states = torch.cat((k_pass, k_rot), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.config._attn_implementation == "flash_attention_2" and self.qk_head_dim != self.v_head_dim: value_states = F.pad(value_states, [0, self.qk_head_dim - self.v_head_dim]) diff --git a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py index 69e5e61d3c..e2c939b255 100755 --- a/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py +++ b/src/transformers/models/deprecated/ernie_m/modeling_ernie_m.py @@ -34,6 +34,7 @@ from ....modeling_outputs import ( from ....modeling_utils import PreTrainedModel from ....pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ....utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, logging +from ....utils.deprecation import deprecate_kwarg from .configuration_ernie_m import ErnieMConfig @@ -118,6 +119,7 @@ class ErnieMSelfAttention(nn.Module): x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -125,7 +127,7 @@ class ErnieMSelfAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: mixed_query_layer = self.q_proj(hidden_states) @@ -135,27 +137,27 @@ class ErnieMSelfAttention(nn.Module): # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if is_cross_attention and past_key_values is not None: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + key_layer = past_key_values[0] + value_layer = past_key_values[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.k_proj(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.v_proj(encoder_hidden_states)) attention_mask = encoder_attention_mask - elif past_key_value is not None: + elif past_key_values is not None: key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = torch.cat([past_key_values[0], key_layer], dim=2) + value_layer = torch.cat([past_key_values[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.k_proj(hidden_states)) value_layer = self.transpose_for_scores(self.v_proj(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None + use_cache = past_key_values is not None if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention @@ -163,8 +165,8 @@ class ErnieMSelfAttention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -216,7 +218,7 @@ class ErnieMSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (past_key_values,) return outputs @@ -245,6 +247,7 @@ class ErnieMAttention(nn.Module): self.self_attn.all_head_size = self.self_attn.attention_head_size * self.self_attn.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -252,7 +255,7 @@ class ErnieMAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.self_attn( @@ -261,7 +264,7 @@ class ErnieMAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) attention_output = self.out_proj(self_outputs[0]) @@ -289,12 +292,13 @@ class ErnieMEncoderLayer(nn.Module): else: self.activation = config.hidden_act + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = True, ): residual = hidden_states @@ -303,7 +307,7 @@ class ErnieMEncoderLayer(nn.Module): hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) @@ -312,7 +316,7 @@ class ErnieMEncoderLayer(nn.Module): hidden_states=hidden_states, attention_mask=attention_mask, head_mask=head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = residual + self.dropout1(hidden_states) @@ -356,13 +360,12 @@ class ErnieMEncoder(nn.Module): hidden_states = hidden_states + (output,) for i, layer in enumerate(self.layers): layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None output, opt_attn_weights = layer( hidden_states=output, attention_mask=attention_mask, head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values[i] if past_key_values is not None else None, ) if output_hidden_states: diff --git a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py index ae15ffd415..49b3e1bb05 100644 --- a/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py +++ b/src/transformers/models/deprecated/gptsan_japanese/modeling_gptsan_japanese.py @@ -30,6 +30,7 @@ from ....utils import ( is_torch_fx_proxy, logging, ) +from ....utils.deprecation import deprecate_kwarg from .configuration_gptsan_japanese import GPTSanJapaneseConfig @@ -377,11 +378,12 @@ class GPTSanJapaneseAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -397,27 +399,27 @@ class GPTSanJapaneseAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as + # `past_key_values[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_values` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] + and past_key_values is not None + and past_key_values[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_values[0] + value_states = past_key_values[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: + elif past_key_values is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = torch.cat([past_key_values[0], key_states], dim=2) + value_states = torch.cat([past_key_values[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) @@ -430,8 +432,8 @@ class GPTSanJapaneseAttention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -495,7 +497,7 @@ class GPTSanJapaneseAttention(nn.Module): attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, past_key_values class GPTSanJapaneseLayerSelfAttention(nn.Module): @@ -513,10 +515,11 @@ class GPTSanJapaneseLayerSelfAttention(nn.Module): ) self.norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_epsilon) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, @@ -558,11 +561,11 @@ class GPTSanJapaneseLayerSelfAttention(nn.Module): """ # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple atten_out = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, attention_mask=(1 - attention_mask) * torch.finfo(hidden_states.dtype).min, layer_head_mask=head_mask, output_attentions=output_attentions, @@ -594,10 +597,11 @@ class GPTSanJapaneseBlock(nn.Module): self.self_attn = GPTSanJapaneseLayerSelfAttention(config) self.feed_forward = GPTSanJapaneseLayerDenseFF(config) if ext_layer else GPTSanJapaneseLayerSparseFF(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = False, @@ -641,7 +645,7 @@ class GPTSanJapaneseBlock(nn.Module): """ atten_out = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, @@ -918,7 +922,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): elif self.config.d_spout and spout is not None: # `spout` is a special input vector specific to GPTSAN # This controls the output by projecting embedded information such as the class of sentences during learning. - # It should passed instead of the first past_key_value. + # It should passed instead of the first past_key_values. # See the original GPTSAN repository for details num_pasts_contexts += 1 @@ -1032,7 +1036,7 @@ class GPTSanJapaneseModel(GPTSanJapanesePreTrainedModel): ) and layer < self.config.num_switch_layers block_output = self.blocks[layer]( hidden_states=hidden_states, - past_key_value=past, + past_key_values=past, attention_mask=extended_attention_mask, head_mask=head_mask, use_cache=self.config.use_cache or use_cache, diff --git a/src/transformers/models/deprecated/mega/modeling_mega.py b/src/transformers/models/deprecated/mega/modeling_mega.py index 9eb1c5c1d6..cc77cb2874 100644 --- a/src/transformers/models/deprecated/mega/modeling_mega.py +++ b/src/transformers/models/deprecated/mega/modeling_mega.py @@ -41,6 +41,7 @@ from ....utils import ( logging, replace_return_docstrings, ) +from ....utils.deprecation import deprecate_kwarg from .configuration_mega import MegaConfig @@ -1173,6 +1174,7 @@ class MegaBlock(nn.Module): else: self.cross_attn = None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1180,7 +1182,7 @@ class MegaBlock(nn.Module): causal_mask: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[torch.FloatTensor]] = None, + past_key_values: Optional[tuple[torch.FloatTensor]] = None, output_attentions: Optional[bool] = False, use_cache: bool = False, ) -> tuple[torch.Tensor]: @@ -1202,14 +1204,14 @@ class MegaBlock(nn.Module): encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, source_sequence_length)`, *optional*): Indicates which entries in the cross/source sequence are to be ignored (mostly due to padding), where elements are either 1 for *not masked* or 0 for *masked*. - past_key_value (`tuple(torch.Tensor)`, *optional*): + past_key_values (`tuple(torch.Tensor)`, *optional*): The hidden states returned from the previous timestep during incremental decoding; expects that self-attention key, value, and EMA states are the first 3 entries in the tuple, and (if doing cross-attention) cross-attention key and value are the last 2 entries in the tuple output_attentions (`bool`, default `False`): Whether to return self-attention weights use_cache (`bool`, default `False`): - Whether to perform incremental decoding; uses `past_key_value` as prior state, and returns the updated + Whether to perform incremental decoding; uses `past_key_values` as prior state, and returns the updated states for use in the next step Returns: @@ -1244,7 +1246,7 @@ class MegaBlock(nn.Module): # sequence length as the input tensor; if we're caching incremental states, we assume the input # sequence length is 1 (Mega will break otherwise), so we take the padding mask for the final # token in the input (mask is received as [batch X sequence length]) - if use_cache and (past_key_value is not None) and (attention_mask is not None): + if use_cache and (past_key_values is not None) and (attention_mask is not None): mega_padding_mask = attention_mask[:, -1].unsqueeze(-1) else: mega_padding_mask = attention_mask @@ -1253,7 +1255,7 @@ class MegaBlock(nn.Module): input=hidden_states, padding_mask=mega_padding_mask, causal_mask=causal_mask, - past_key_values=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -1272,7 +1274,7 @@ class MegaBlock(nn.Module): key=encoder_hidden_states, value=encoder_hidden_states, key_padding_mask=encoder_attention_mask, - past_key_values=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -1592,7 +1594,7 @@ class MegaModel(MegaPreTrainedModel): causal_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=current_decoder_cache, + past_key_values=current_decoder_cache, output_attentions=output_attentions, use_cache=use_cache, ) diff --git a/src/transformers/models/deprecated/nezha/modeling_nezha.py b/src/transformers/models/deprecated/nezha/modeling_nezha.py index 692d5dd092..635a078c6a 100644 --- a/src/transformers/models/deprecated/nezha/modeling_nezha.py +++ b/src/transformers/models/deprecated/nezha/modeling_nezha.py @@ -47,6 +47,7 @@ from ....utils import ( logging, replace_return_docstrings, ) +from ....utils.deprecation import deprecate_kwarg from .configuration_nezha import NezhaConfig @@ -242,6 +243,7 @@ class NezhaSelfAttention(nn.Module): x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -249,7 +251,7 @@ class NezhaSelfAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -259,20 +261,20 @@ class NezhaSelfAttention(nn.Module): # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if is_cross_attention and past_key_values is not None: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + key_layer = past_key_values[0] + value_layer = past_key_values[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask - elif past_key_value is not None: + elif past_key_values is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = torch.cat([past_key_values[0], key_layer], dim=2) + value_layer = torch.cat([past_key_values[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -286,8 +288,8 @@ class NezhaSelfAttention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -343,7 +345,7 @@ class NezhaSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (past_key_values,) return outputs @@ -386,6 +388,7 @@ class NezhaAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -393,7 +396,7 @@ class NezhaAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.self( @@ -402,7 +405,7 @@ class NezhaAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -454,6 +457,7 @@ class NezhaLayer(GradientCheckpointingLayer): self.intermediate = NezhaIntermediate(config) self.output = NezhaOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -461,17 +465,17 @@ class NezhaLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] @@ -490,8 +494,8 @@ class NezhaLayer(GradientCheckpointingLayer): " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + # cross_attn cached key/values tuple is at positions 3,4 of past_key_values tuple + cross_attn_past_key_value = past_key_values[-2:] if past_key_values is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -562,7 +566,6 @@ class NezhaEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -570,7 +573,7 @@ class NezhaEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values[i] if past_key_values is not None else None, output_attentions, ) diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 9447e17473..fab787f8e7 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -33,6 +33,7 @@ from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ....utils.deprecation import deprecate_kwarg from .configuration_open_llama import OpenLlamaConfig @@ -267,12 +268,13 @@ class OpenLlamaAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -283,18 +285,18 @@ class OpenLlamaAttention(nn.Module): value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + if past_key_values is not None: + kv_seq_len += past_key_values[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] - if past_key_value is not None: + if past_key_values is not None: # reuse k, v, self_attention - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = torch.cat([past_key_values[0], key_states], dim=2) + value_states = torch.cat([past_key_values[1], value_states], dim=2) - past_key_value = (key_states, value_states) if use_cache else None + past_key_values = (key_states, value_states) if use_cache else None if self.config.use_memory_efficient_attention and xops is not None and self.training: attn_weights = None @@ -341,7 +343,7 @@ class OpenLlamaAttention(nn.Module): if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class OpenLlamaDecoderLayer(GradientCheckpointingLayer): @@ -358,12 +360,13 @@ class OpenLlamaDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OpenLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: @@ -378,7 +381,7 @@ class OpenLlamaDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states @@ -390,7 +393,7 @@ class OpenLlamaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -628,13 +631,11 @@ class OpenLlamaModel(OpenLlamaPreTrainedModel): if output_hidden_states: all_hidden_states += (hidden_states,) - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values[idx] if past_key_values is not None else None, output_attentions=output_attentions, use_cache=use_cache, ) diff --git a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py index 7245f44c34..cfa66aaf02 100755 --- a/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py +++ b/src/transformers/models/deprecated/qdqbert/modeling_qdqbert.py @@ -49,6 +49,7 @@ from ....utils import ( replace_return_docstrings, requires_backends, ) +from ....utils.deprecation import deprecate_kwarg from .configuration_qdqbert import QDQBertConfig @@ -242,6 +243,7 @@ class QDQBertSelfAttention(nn.Module): x = x.view(*new_x_shape) return x.permute(0, 2, 1, 3) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -249,7 +251,7 @@ class QDQBertSelfAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, ): mixed_query_layer = self.query(hidden_states) @@ -259,20 +261,20 @@ class QDQBertSelfAttention(nn.Module): # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if is_cross_attention and past_key_values is not None: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + key_layer = past_key_values[0] + value_layer = past_key_values[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask - elif past_key_value is not None: + elif past_key_values is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = torch.cat([past_key_values[0], key_layer], dim=2) + value_layer = torch.cat([past_key_values[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) @@ -286,8 +288,8 @@ class QDQBertSelfAttention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul( @@ -337,7 +339,7 @@ class QDQBertSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (past_key_values,) return outputs @@ -390,6 +392,7 @@ class QDQBertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -397,7 +400,7 @@ class QDQBertAttention(nn.Module): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, ): self_outputs = self.self( @@ -406,7 +409,7 @@ class QDQBertAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -467,6 +470,7 @@ class QDQBertLayer(GradientCheckpointingLayer): self.intermediate = QDQBertIntermediate(config) self.output = QDQBertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -474,17 +478,17 @@ class QDQBertLayer(GradientCheckpointingLayer): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, ): # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] @@ -503,8 +507,8 @@ class QDQBertLayer(GradientCheckpointingLayer): " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + # cross_attn cached key/values tuple is at positions 3,4 of past_key_values tuple + cross_attn_past_key_value = past_key_values[-2:] if past_key_values is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -567,7 +571,6 @@ class QDQBertEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -575,7 +578,7 @@ class QDQBertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values[i] if past_key_values is not None else None, output_attentions, ) diff --git a/src/transformers/models/deprecated/realm/modeling_realm.py b/src/transformers/models/deprecated/realm/modeling_realm.py index 767bcf5a9c..8021a142dd 100644 --- a/src/transformers/models/deprecated/realm/modeling_realm.py +++ b/src/transformers/models/deprecated/realm/modeling_realm.py @@ -34,6 +34,7 @@ from ....modeling_outputs import ( from ....modeling_utils import PreTrainedModel from ....pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings +from ....utils.deprecation import deprecate_kwarg from .configuration_realm import RealmConfig @@ -247,6 +248,7 @@ class RealmSelfAttention(nn.Module): x = x.view(new_x_shape) return x.permute(0, 2, 1, 3) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -254,7 +256,7 @@ class RealmSelfAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: mixed_query_layer = self.query(hidden_states) @@ -264,27 +266,27 @@ class RealmSelfAttention(nn.Module): # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if is_cross_attention and past_key_value is not None: + if is_cross_attention and past_key_values is not None: # reuse k,v, cross_attentions - key_layer = past_key_value[0] - value_layer = past_key_value[1] + key_layer = past_key_values[0] + value_layer = past_key_values[1] attention_mask = encoder_attention_mask elif is_cross_attention: key_layer = self.transpose_for_scores(self.key(encoder_hidden_states)) value_layer = self.transpose_for_scores(self.value(encoder_hidden_states)) attention_mask = encoder_attention_mask - elif past_key_value is not None: + elif past_key_values is not None: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) - key_layer = torch.cat([past_key_value[0], key_layer], dim=2) - value_layer = torch.cat([past_key_value[1], value_layer], dim=2) + key_layer = torch.cat([past_key_values[0], key_layer], dim=2) + value_layer = torch.cat([past_key_values[1], value_layer], dim=2) else: key_layer = self.transpose_for_scores(self.key(hidden_states)) value_layer = self.transpose_for_scores(self.value(hidden_states)) query_layer = self.transpose_for_scores(mixed_query_layer) - use_cache = past_key_value is not None + use_cache = past_key_values is not None if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention @@ -292,8 +294,8 @@ class RealmSelfAttention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_layer, value_layer) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_layer, value_layer) # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -345,7 +347,7 @@ class RealmSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (past_key_values,) return outputs @@ -395,6 +397,7 @@ class RealmAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -402,7 +405,7 @@ class RealmAttention(nn.Module): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: self_outputs = self.self( @@ -411,7 +414,7 @@ class RealmAttention(nn.Module): head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values, output_attentions, ) attention_output = self.output(self_outputs[0], hidden_states) @@ -463,6 +466,7 @@ class RealmLayer(GradientCheckpointingLayer): self.intermediate = RealmIntermediate(config) self.output = RealmOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -470,17 +474,17 @@ class RealmLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[tuple[tuple[torch.FloatTensor]]] = None, + past_key_values: Optional[tuple[tuple[torch.FloatTensor]]] = None, output_attentions: Optional[bool] = False, ) -> tuple[torch.Tensor]: # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None self_attention_outputs = self.attention( hidden_states, attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, ) attention_output = self_attention_outputs[0] @@ -499,8 +503,8 @@ class RealmLayer(GradientCheckpointingLayer): " by setting `config.add_cross_attention=True`" ) - # cross_attn cached key/values tuple is at positions 3,4 of past_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + # cross_attn cached key/values tuple is at positions 3,4 of past_key_values tuple + cross_attn_past_key_value = past_key_values[-2:] if past_key_values is not None else None cross_attention_outputs = self.crossattention( attention_output, attention_mask, @@ -571,7 +575,6 @@ class RealmEncoder(nn.Module): all_hidden_states = all_hidden_states + (hidden_states,) layer_head_mask = head_mask[i] if head_mask is not None else None - past_key_value = past_key_values[i] if past_key_values is not None else None layer_outputs = layer_module( hidden_states, @@ -579,7 +582,7 @@ class RealmEncoder(nn.Module): layer_head_mask, encoder_hidden_states, encoder_attention_mask, - past_key_value, + past_key_values[i] if past_key_values is not None else None, output_attentions, ) diff --git a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py index bf99ba6255..2117526b04 100755 --- a/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py +++ b/src/transformers/models/deprecated/speech_to_text_2/modeling_speech_to_text_2.py @@ -27,6 +27,7 @@ from ....modeling_layers import GradientCheckpointingLayer from ....modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ....modeling_utils import PreTrainedModel from ....utils import add_start_docstrings, logging, replace_return_docstrings +from ....utils.deprecation import deprecate_kwarg from .configuration_speech_to_text_2 import Speech2Text2Config @@ -142,11 +143,12 @@ class Speech2Text2Attention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -162,27 +164,27 @@ class Speech2Text2Attention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as + # `past_key_values[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_values` is the same as # the provided `key_value_states` to support prefix tuning if ( is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] + and past_key_values is not None + and past_key_values[0].shape[2] == key_value_states.shape[1] ): # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_values[0] + value_states = past_key_values[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.k_proj(key_value_states), -1, bsz) value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: + elif past_key_values is not None: # reuse k, v, self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = torch.cat([past_key_values[0], key_states], dim=2) + value_states = torch.cat([past_key_values[1], value_states], dim=2) else: # self_attention key_states = self._shape(self.k_proj(hidden_states), -1, bsz) @@ -195,8 +197,8 @@ class Speech2Text2Attention(nn.Module): # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of # all previous decoder key/value_states. Further calls to uni-directional self-attention # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_states, value_states) proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -260,7 +262,7 @@ class Speech2Text2Attention(nn.Module): attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, past_key_values class Speech2Text2DecoderLayer(GradientCheckpointingLayer): @@ -293,6 +295,7 @@ class Speech2Text2DecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -301,7 +304,7 @@ class Speech2Text2DecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, ): @@ -318,7 +321,7 @@ class Speech2Text2DecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size *(decoder_attention_heads,)*. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -327,11 +330,11 @@ class Speech2Text2DecoderLayer(GradientCheckpointingLayer): # Self Attention # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None # add present self-attn cache to positions 1,2 of present_key_value tuple hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -347,13 +350,13 @@ class Speech2Text2DecoderLayer(GradientCheckpointingLayer): residual = hidden_states # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_past_key_value = past_key_values[-2:] if past_key_values is not None else None hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_values=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -604,8 +607,6 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=attention_mask, @@ -613,7 +614,7 @@ class Speech2Text2Decoder(Speech2Text2PreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_value, + past_key_values=past_key_values[idx] if past_key_values is not None else None, output_attentions=output_attentions, use_cache=use_cache, ) diff --git a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py index 7b0c3aede2..35d500731e 100644 --- a/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/deprecated/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -36,6 +36,7 @@ from ....utils import ( logging, replace_return_docstrings, ) +from ....utils.deprecation import deprecate_kwarg from .configuration_xlm_prophetnet import XLMProphetNetConfig @@ -650,13 +651,14 @@ class XLMProphetNetAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_attn_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - past_key_value: Optional[tuple[Tensor]] = None, + past_key_values: Optional[tuple[Tensor]] = None, output_attentions: bool = False, ) -> tuple[Tensor, Optional[Tensor]]: batch_size, tgt_len, hidden_size = hidden_states.size() @@ -673,10 +675,10 @@ class XLMProphetNetAttention(nn.Module): # previous time steps are cached - no need to recompute key and value if they are static query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) - if is_cross_attention and past_key_value is not None: + if is_cross_attention and past_key_values is not None: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] + key_states = past_key_values[0] + value_states = past_key_values[1] elif is_cross_attention: # cross_attentions key_states = self._shape(self.key_proj(key_value_states), -1, batch_size) @@ -690,8 +692,8 @@ class XLMProphetNetAttention(nn.Module): # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. # Further calls to cross_attention layer can then reuse all cross-attention # key/value_states (first "if" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + # if encoder bi-directional self-attention `past_key_values` is always `None` + past_key_values = (key_states, value_states) # project states into the correct shape proj_shape = (batch_size, self.num_attn_heads, -1, self.head_dim) @@ -746,7 +748,7 @@ class XLMProphetNetAttention(nn.Module): attn_output = self.out_proj(attn_output) attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, past_key_values class XLMProphetNetFeedForward(nn.Module): @@ -808,10 +810,11 @@ class XLMProphetNetNgramSelfAttention(nn.Module): def prepare_for_onnx_export_(self): self.onnx_trace = True + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, - past_key_value: Optional[tuple[Tensor]] = None, + past_key_values: Optional[tuple[Tensor]] = None, attention_mask=None, layer_head_mask=None, extended_predict_attention_mask=None, @@ -855,14 +858,14 @@ class XLMProphetNetNgramSelfAttention(nn.Module): main_value_states, predict_value_states_list = value_states_list[0], value_states_list[1:] # saved states are stored with shape (batch_size, num_attn_heads, seq_len, head_dim) - if past_key_value is not None: - prev_main_key_states = past_key_value[0] + if past_key_values is not None: + prev_main_key_states = past_key_values[0] main_key_states = torch.cat((prev_main_key_states, main_key_states), dim=2) - prev_main_value_states = past_key_value[1] + prev_main_value_states = past_key_values[1] main_value_states = torch.cat((prev_main_value_states, main_value_states), dim=2) # Update cache - past_key_value = (main_key_states, main_value_states) + past_key_values = (main_key_states, main_value_states) # get seq_length of main stream only sequence_length = ngram_sequence_length // (1 + self.ngram) @@ -984,7 +987,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module): attn_output = nn.functional.dropout(attn_output, p=self.dropout, training=self.training) - return attn_output, main_attn_probs, predict_attn_probs, past_key_value + return attn_output, main_attn_probs, predict_attn_probs, past_key_values def get_main_relative_pos_embeddings( self, hidden_states, attn_weights, position_ids, main_relative_position_buckets @@ -1154,6 +1157,7 @@ class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): self.feed_forward = XLMProphetNetFeedForward(config, config.decoder_ffn_dim) self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -1166,16 +1170,16 @@ class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): main_relative_position_buckets=None, predict_relative_position_buckets=None, position_ids=None, - past_key_value=None, + past_key_values=None, use_cache: bool = True, output_attentions: bool = False, ): # 1st residual block # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + self_attn_past_key_value = past_key_values[:2] if past_key_values is not None else None ngram_attention_output, self_attn_weights, self_attn_weights_ngram, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_values=self_attn_past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, extended_predict_attention_mask=extended_predict_attention_mask, @@ -1186,7 +1190,7 @@ class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): hidden_states = self.self_attn_layer_norm(hidden_states + ngram_attention_output) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + cross_attn_past_key_value = past_key_values[-2:] if past_key_values is not None else None cross_attn_weights = None if encoder_hidden_states is not None: # 2nd residual block @@ -1195,7 +1199,7 @@ class XLMProphetNetDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_values=cross_attn_past_key_value, output_attentions=output_attentions, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) @@ -1544,8 +1548,6 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): if self.config.ngram > 0: all_ngram_stream_hidden_states += (hidden_states[:, sequence_length:],) - past_key_value = past_key_values[idx] if past_key_values is not None else None - layer_outputs = decoder_layer( hidden_states, attention_mask=extended_attention_mask, @@ -1557,7 +1559,7 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel): main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values[idx] if past_key_values is not None else None, use_cache=use_cache, output_attentions=output_attentions, ) diff --git a/src/transformers/models/dia/modeling_dia.py b/src/transformers/models/dia/modeling_dia.py index 8b31ddfb7f..bfcc34aea0 100644 --- a/src/transformers/models/dia/modeling_dia.py +++ b/src/transformers/models/dia/modeling_dia.py @@ -48,6 +48,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_dia import DiaConfig, DiaDecoderConfig, DiaEncoderConfig from .generation_dia import DiaGenerationMixin @@ -268,12 +269,13 @@ class DiaSelfAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -287,10 +289,10 @@ class DiaSelfAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py index b0e56f702c..93bbb00824 100644 --- a/src/transformers/models/diffllama/modeling_diffllama.py +++ b/src/transformers/models/diffllama/modeling_diffllama.py @@ -44,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_diffllama import DiffLlamaConfig @@ -154,13 +155,14 @@ class DiffLlamaAttention(nn.Module): self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -179,10 +181,10 @@ class DiffLlamaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -232,17 +234,18 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, None]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -273,10 +276,10 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -373,13 +376,14 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): """ # Adapted from DiffLlamaAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -397,10 +401,10 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -488,12 +492,13 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = DiffLlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -506,7 +511,7 @@ class DiffLlamaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -651,7 +656,7 @@ class DiffLlamaModel(DiffLlamaPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/diffllama/modular_diffllama.py b/src/transformers/models/diffllama/modular_diffllama.py index 1452ebda28..1ea0a47058 100644 --- a/src/transformers/models/diffllama/modular_diffllama.py +++ b/src/transformers/models/diffllama/modular_diffllama.py @@ -24,6 +24,7 @@ from torch import nn from ...cache_utils import Cache, StaticCache from ...modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import GemmaForCausalLM from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -90,13 +91,14 @@ class DiffLlamaAttention(nn.Module): self.lambda_k2 = nn.Parameter(torch.normal(0, config.lambda_std_dev, size=(self.head_dim,))) self.groupnorm = nn.RMSNorm(2 * self.head_dim, eps=config.rms_norm_eps, elementwise_affine=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -115,10 +117,10 @@ class DiffLlamaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -168,17 +170,18 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, None]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -209,10 +212,10 @@ class DiffLlamaFlashAttention2(DiffLlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -309,13 +312,14 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): """ # Adapted from DiffLlamaAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -333,10 +337,10 @@ class DiffLlamaSdpaAttention(DiffLlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py index d83b6f1796..1371ade487 100644 --- a/src/transformers/models/doge/modeling_doge.py +++ b/src/transformers/models/doge/modeling_doge.py @@ -40,6 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import AttentionInterface, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_doge import DogeConfig @@ -261,12 +262,13 @@ class DogeAttention(nn.Module): self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -280,10 +282,10 @@ class DogeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # calculate dynamic mask from value_states dt_states = self.dt_proj( @@ -441,13 +443,14 @@ class DogeDecoderLayer(GradientCheckpointingLayer): self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config) self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -460,7 +463,7 @@ class DogeDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -578,7 +581,7 @@ class DogeModel(DogePreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/doge/modular_doge.py b/src/transformers/models/doge/modular_doge.py index f9b8154ab1..482b3583a0 100644 --- a/src/transformers/models/doge/modular_doge.py +++ b/src/transformers/models/doge/modular_doge.py @@ -33,6 +33,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import AttentionInterface from ...processing_utils import Unpack from ...utils import TransformersKwargs, is_torch_flex_attn_available +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from ..llama.modeling_llama import ( LlamaForSequenceClassification, @@ -357,12 +358,13 @@ class DogeAttention(nn.Module): self.q_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = DogeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -376,10 +378,10 @@ class DogeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # calculate dynamic mask from value_states dt_states = self.dt_proj( @@ -525,13 +527,14 @@ class DogeDecoderLayer(GradientCheckpointingLayer): self.mlp = DogeMLP(config) if not config.is_moe else DogeCDMoE(config) self.post_attention_residual = nn.Parameter(torch.ones(config.hidden_size)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], @@ -544,7 +547,7 @@ class DogeDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index d3de040241..2986de2afa 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -36,6 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_dots1 import Dots1Config @@ -198,12 +199,13 @@ class Dots1Attention(nn.Module): self.k_norm = Dots1RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -217,10 +219,10 @@ class Dots1Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -375,12 +377,13 @@ class Dots1DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Dots1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -393,7 +396,7 @@ class Dots1DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -512,7 +515,7 @@ class Dots1Model(Dots1PreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py index 0bf12f0511..f402614e75 100644 --- a/src/transformers/models/electra/modeling_electra.py +++ b/src/transformers/models/electra/modeling_electra.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_electra import ElectraConfig @@ -224,13 +225,14 @@ class ElectraSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -241,19 +243,19 @@ class ElectraSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -267,7 +269,7 @@ class ElectraSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -275,14 +277,14 @@ class ElectraSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -377,13 +379,14 @@ class ElectraAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -392,7 +395,7 @@ class ElectraAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -448,6 +451,7 @@ class ElectraLayer(GradientCheckpointingLayer): self.intermediate = ElectraIntermediate(config) self.output = ElectraOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -455,7 +459,7 @@ class ElectraLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -464,7 +468,7 @@ class ElectraLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -482,7 +486,7 @@ class ElectraLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -557,7 +561,7 @@ class ElectraEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/emu3/modeling_emu3.py b/src/transformers/models/emu3/modeling_emu3.py index 6f4cafab7c..064a26df06 100644 --- a/src/transformers/models/emu3/modeling_emu3.py +++ b/src/transformers/models/emu3/modeling_emu3.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_emu3 import Emu3Config, Emu3TextConfig, Emu3VQVAEConfig @@ -141,12 +142,13 @@ class Emu3Attention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -160,10 +162,10 @@ class Emu3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -234,12 +236,13 @@ class Emu3DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Emu3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.dropout = nn.Dropout(config.attention_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -252,7 +255,7 @@ class Emu3DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -1211,7 +1214,7 @@ class Emu3TextModel(Emu3PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/emu3/modular_emu3.py b/src/transformers/models/emu3/modular_emu3.py index b714e8c732..2933e89757 100644 --- a/src/transformers/models/emu3/modular_emu3.py +++ b/src/transformers/models/emu3/modular_emu3.py @@ -29,6 +29,7 @@ from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..chameleon.modeling_chameleon import ( ChameleonPreTrainedModel, ChameleonVQVAEEncoderConvDownsample, @@ -51,12 +52,13 @@ class Emu3DecoderLayer(LlamaDecoderLayer): super().__init__(config, layer_idx) self.dropout = nn.Dropout(config.attention_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -69,7 +71,7 @@ class Emu3DecoderLayer(LlamaDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/ernie/modeling_ernie.py b/src/transformers/models/ernie/modeling_ernie.py index 98cdda20ef..bea39d1d8b 100644 --- a/src/transformers/models/ernie/modeling_ernie.py +++ b/src/transformers/models/ernie/modeling_ernie.py @@ -42,6 +42,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_ernie import ErnieConfig @@ -153,13 +154,14 @@ class ErnieSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -170,19 +172,19 @@ class ErnieSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -196,7 +198,7 @@ class ErnieSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -204,14 +206,14 @@ class ErnieSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -306,13 +308,14 @@ class ErnieAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -321,7 +324,7 @@ class ErnieAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -377,6 +380,7 @@ class ErnieLayer(GradientCheckpointingLayer): self.intermediate = ErnieIntermediate(config) self.output = ErnieOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -384,7 +388,7 @@ class ErnieLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -393,7 +397,7 @@ class ErnieLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -411,7 +415,7 @@ class ErnieLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -422,6 +426,7 @@ class ErnieLayer(GradientCheckpointingLayer): self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output ) outputs = (layer_output,) + outputs + return outputs def feed_forward_chunk(self, attention_output): @@ -485,7 +490,7 @@ class ErnieEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py index 38a3e5f563..1125e53727 100644 --- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py +++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py @@ -34,6 +34,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_ernie4_5 import Ernie4_5Config @@ -192,12 +193,13 @@ class Ernie4_5Attention(nn.Module): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -211,10 +213,10 @@ class Ernie4_5Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -268,12 +270,13 @@ class Ernie4_5DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Ernie4_5RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -286,7 +289,7 @@ class Ernie4_5DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -387,7 +390,7 @@ class Ernie4_5Model(Ernie4_5PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py index 3c579d5bd1..2df2f0d871 100644 --- a/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py @@ -36,6 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_ernie4_5_moe import Ernie4_5_MoeConfig @@ -215,12 +216,13 @@ class Ernie4_5_MoeAttention(nn.Module): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -234,10 +236,10 @@ class Ernie4_5_MoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -400,13 +402,14 @@ class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) self.post_attention_layernorm = Ernie4_5_MoeRMSNorm(config.hidden_size, config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> torch.FloatTensor: @@ -424,7 +427,7 @@ class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -444,7 +447,7 @@ class Ernie4_5_MoeDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -557,7 +560,7 @@ class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py index 9641532957..2c35b7fa06 100644 --- a/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +++ b/src/transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py @@ -297,7 +297,7 @@ class Ernie4_5_MoeModel(Ernie4_5_MoePreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/evolla/modeling_evolla.py b/src/transformers/models/evolla/modeling_evolla.py index 61746685ee..695372fa3a 100644 --- a/src/transformers/models/evolla/modeling_evolla.py +++ b/src/transformers/models/evolla/modeling_evolla.py @@ -52,6 +52,7 @@ from ...modeling_utils import ( ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_evolla import EvollaConfig, SaProtConfig @@ -1134,6 +1135,7 @@ class EvollaSequenceAlignerCrossAttention(nn.Module): return context_layer + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, query_states, @@ -1147,7 +1149,7 @@ class EvollaSequenceAlignerCrossAttention(nn.Module): protein_batch_mask=None, structure_batch_mask=None, msa_batch_mask=None, - past_key_value=None, + past_key_values=None, ): if protein_kv_states is not None: bs, protein_kv_seq_len, dim = protein_kv_states.shape @@ -1379,12 +1381,13 @@ class EvollaAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -1398,10 +1401,10 @@ class EvollaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1439,13 +1442,14 @@ class EvollaDecoderLayer(GradientCheckpointingLayer): protein_encoder_dim=config.hidden_size, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, protein_kv_states: Optional[torch.Tensor] = None, @@ -1466,7 +1470,7 @@ class EvollaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -1636,7 +1640,7 @@ class EvollaModel(EvollaPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/evolla/modular_evolla.py b/src/transformers/models/evolla/modular_evolla.py index 7abd74e435..5783f869bd 100644 --- a/src/transformers/models/evolla/modular_evolla.py +++ b/src/transformers/models/evolla/modular_evolla.py @@ -36,6 +36,7 @@ from ...utils import ( can_return_tuple, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from ..esm.modeling_esm import ( EsmAttention, @@ -613,6 +614,7 @@ class EvollaSequenceAlignerCrossAttention(nn.Module): return context_layer + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, query_states, @@ -626,7 +628,7 @@ class EvollaSequenceAlignerCrossAttention(nn.Module): protein_batch_mask=None, structure_batch_mask=None, msa_batch_mask=None, - past_key_value=None, + past_key_values=None, ): if protein_kv_states is not None: bs, protein_kv_seq_len, dim = protein_kv_states.shape @@ -712,13 +714,14 @@ class EvollaDecoderLayer(LlamaDecoderLayer): protein_encoder_dim=config.hidden_size, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, protein_kv_states: Optional[torch.Tensor] = None, @@ -739,7 +742,7 @@ class EvollaDecoderLayer(LlamaDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -895,7 +898,7 @@ class EvollaModel(EvollaPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/exaone4/modeling_exaone4.py b/src/transformers/models/exaone4/modeling_exaone4.py index 7dbd77a0c0..bbd5d0678c 100644 --- a/src/transformers/models/exaone4/modeling_exaone4.py +++ b/src/transformers/models/exaone4/modeling_exaone4.py @@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from .configuration_exaone4 import Exaone4Config @@ -200,12 +201,13 @@ class Exaone4Attention(nn.Module): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -225,11 +227,11 @@ class Exaone4Attention(nn.Module): if self.sliding_window is None or self.is_sliding: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = { "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -278,12 +280,13 @@ class Exaone4DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Exaone4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -294,7 +297,7 @@ class Exaone4DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -408,7 +411,7 @@ class Exaone4Model(Exaone4PreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[layer_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/exaone4/modular_exaone4.py b/src/transformers/models/exaone4/modular_exaone4.py index 41200030bb..fa7512a2ca 100644 --- a/src/transformers/models/exaone4/modular_exaone4.py +++ b/src/transformers/models/exaone4/modular_exaone4.py @@ -35,6 +35,7 @@ from ...utils import ( TransformersKwargs, logging, ) +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaForCausalLM, LlamaForQuestionAnswering, @@ -287,12 +288,13 @@ class Exaone4Attention(nn.Module): self.q_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Exaone4RMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -312,11 +314,11 @@ class Exaone4Attention(nn.Module): if self.sliding_window is None or self.is_sliding: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = { "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -422,7 +424,7 @@ class Exaone4Model(Exaone4PreTrainedModel, LlamaModel): position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[layer_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/falcon_h1/modeling_falcon_h1.py b/src/transformers/models/falcon_h1/modeling_falcon_h1.py index 5e60f1a8bb..b9b38b4c9c 100644 --- a/src/transformers/models/falcon_h1/modeling_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modeling_falcon_h1.py @@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config @@ -351,12 +352,13 @@ class FalconH1Attention(nn.Module): ) self.key_multiplier = config.key_multiplier + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -370,10 +372,10 @@ class FalconH1Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1071,13 +1073,14 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, mamba_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[FalconHybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1089,7 +1092,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1111,7 +1114,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): mamba_hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, cache_position=cache_position, attention_mask=mamba_attention_mask, ) @@ -1121,7 +1124,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states * self.attention_in_multiplier, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1309,7 +1312,7 @@ class FalconH1Model(FalconH1PreTrainedModel): attention_mask=causal_mask, mamba_attention_mask=mamba_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/falcon_h1/modular_falcon_h1.py b/src/transformers/models/falcon_h1/modular_falcon_h1.py index afbacc5666..10aaadcdd1 100644 --- a/src/transformers/models/falcon_h1/modular_falcon_h1.py +++ b/src/transformers/models/falcon_h1/modular_falcon_h1.py @@ -52,6 +52,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_falcon_h1 import FalconH1Config @@ -204,12 +205,13 @@ class FalconH1Attention(LlamaAttention): super().__init__(config, layer_idx) self.key_multiplier = config.key_multiplier + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -223,10 +225,10 @@ class FalconH1Attention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -843,13 +845,14 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = FalconH1RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, mamba_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[FalconHybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[FalconHybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -861,7 +864,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`FalconHybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -883,7 +886,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): mamba_hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, cache_position=cache_position, attention_mask=mamba_attention_mask, ) @@ -893,7 +896,7 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states * self.attention_in_multiplier, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1081,7 +1084,7 @@ class FalconH1Model(FalconH1PreTrainedModel): attention_mask=causal_mask, mamba_attention_mask=mamba_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/funnel/modeling_tf_funnel.py b/src/transformers/models/funnel/modeling_tf_funnel.py index 86982ba295..3d57fa99ea 100644 --- a/src/transformers/models/funnel/modeling_tf_funnel.py +++ b/src/transformers/models/funnel/modeling_tf_funnel.py @@ -1342,20 +1342,24 @@ class TFFunnelForPreTraining(TFFunnelPreTrainedModel): **kwargs, ) -> tuple[tf.Tensor] | TFFunnelForPreTrainingOutput: r""" - Returns: + Returns: - Examples: + Examples: - ```python - >>> from transformers import AutoTokenizer, TFFunnelForPreTraining - >>> import torch + ```python + >>> from transformers import AutoTokenizer, TFFunnelForPreTraining + >>> import torch + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg - >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small") - >>> model = TFFunnelForPreTraining.from_pretrained("funnel-transformer/small") + >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small") + >>> model = TFFunnelForPreTraining.from_pretrained("funnel-transformer/small") - >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") - >>> logits = model(inputs).logits - ```""" + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="tf") + >>> logits = model(inputs).logits + ```""" discriminator_hidden_states = self.funnel( input_ids, attention_mask, diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index ce6b795527..8710c418e6 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -38,6 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_gemma import GemmaConfig @@ -212,12 +213,13 @@ class GemmaAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -231,10 +233,10 @@ class GemmaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -267,12 +269,13 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -285,7 +288,7 @@ class GemmaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -395,7 +398,7 @@ class GemmaModel(GemmaPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/gemma/modular_gemma.py b/src/transformers/models/gemma/modular_gemma.py index b715f377c6..67aedbd551 100644 --- a/src/transformers/models/gemma/modular_gemma.py +++ b/src/transformers/models/gemma/modular_gemma.py @@ -422,7 +422,7 @@ class GemmaModel(LlamaModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py index 80b3c647b1..583d3de289 100644 --- a/src/transformers/models/gemma2/modeling_gemma2.py +++ b/src/transformers/models/gemma2/modeling_gemma2.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_gemma2 import Gemma2Config @@ -191,12 +192,13 @@ class Gemma2Attention(nn.Module): self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -210,10 +212,10 @@ class Gemma2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -251,13 +253,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer): self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -273,7 +276,7 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -456,7 +459,7 @@ class Gemma2Model(Gemma2PreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gemma2/modular_gemma2.py b/src/transformers/models/gemma2/modular_gemma2.py index 1f22987e67..398d036caf 100644 --- a/src/transformers/models/gemma2/modular_gemma2.py +++ b/src/transformers/models/gemma2/modular_gemma2.py @@ -29,6 +29,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import ( GemmaAttention, GemmaForCausalLM, @@ -256,12 +257,13 @@ class Gemma2Attention(GemmaAttention): self.scaling = config.query_pre_attn_scalar**-0.5 self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -275,10 +277,10 @@ class Gemma2Attention(GemmaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -316,13 +318,14 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer): self.pre_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -338,7 +341,7 @@ class Gemma2DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -453,7 +456,7 @@ class Gemma2Model(GemmaModel): position_embeddings=position_embeddings, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gemma3/modeling_gemma3.py b/src/transformers/models/gemma3/modeling_gemma3.py index 95aa3ab6f0..9e45c1f162 100644 --- a/src/transformers/models/gemma3/modeling_gemma3.py +++ b/src/transformers/models/gemma3/modeling_gemma3.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from ..auto import AutoModel from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig @@ -298,12 +299,13 @@ class Gemma3Attention(nn.Module): self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -320,10 +322,10 @@ class Gemma3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -360,6 +362,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -367,7 +370,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): position_embeddings_local: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -388,7 +391,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -555,7 +558,7 @@ class Gemma3TextModel(Gemma3PreTrainedModel): position_embeddings_local=position_embeddings_local, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gemma3/modular_gemma3.py b/src/transformers/models/gemma3/modular_gemma3.py index df379b5544..af355d7b74 100644 --- a/src/transformers/models/gemma3/modular_gemma3.py +++ b/src/transformers/models/gemma3/modular_gemma3.py @@ -32,6 +32,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -401,12 +402,13 @@ class Gemma3Attention(Gemma2Attention): self.q_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) self.k_norm = Gemma3RMSNorm(dim=config.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -423,10 +425,10 @@ class Gemma3Attention(Gemma2Attention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -463,6 +465,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): self.pre_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Gemma3RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -470,7 +473,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): position_embeddings_local: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -491,7 +494,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -633,7 +636,7 @@ class Gemma3TextModel(Gemma2Model): position_embeddings_local=position_embeddings_local, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gemma3n/modeling_gemma3n.py b/src/transformers/models/gemma3n/modeling_gemma3n.py index 59aefc7742..72b19b639d 100644 --- a/src/transformers/models/gemma3n/modeling_gemma3n.py +++ b/src/transformers/models/gemma3n/modeling_gemma3n.py @@ -40,6 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig @@ -1302,12 +1303,13 @@ class Gemma3nTextAttention(nn.Module): else None ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -1321,9 +1323,9 @@ class Gemma3nTextAttention(nn.Module): query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) - if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None: # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - layer = past_key_value.layers[self.kv_shared_layer_index] + layer = past_key_values.layers[self.kv_shared_layer_index] # Device of past layer may be different from current one indices = cache_position.to(layer.keys.device) # Sliding window cache layers might have smaller size (for full layers, we never go beyond) @@ -1346,7 +1348,7 @@ class Gemma3nTextAttention(nn.Module): value_states = self.v_norm(value_states) value_states = value_states.transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -1354,7 +1356,7 @@ class Gemma3nTextAttention(nn.Module): "cache_position": cache_position, "sliding_window": self.sliding_window, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1400,6 +1402,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1408,7 +1411,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): per_layer_input: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1431,7 +1434,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1672,7 +1675,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel): per_layer_input, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gemma3n/modular_gemma3n.py b/src/transformers/models/gemma3n/modular_gemma3n.py index fd40253553..e29ff12ca2 100644 --- a/src/transformers/models/gemma3n/modular_gemma3n.py +++ b/src/transformers/models/gemma3n/modular_gemma3n.py @@ -32,6 +32,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( @@ -1749,12 +1750,13 @@ class Gemma3nTextAttention(Gemma3Attention): else None ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -1768,9 +1770,9 @@ class Gemma3nTextAttention(Gemma3Attention): query_states = apply_rotary_pos_emb(query_states, cos, sin, unsqueeze_dim=2) query_states = query_states.transpose(1, 2) - if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_value is not None: + if self.is_kv_shared_layer and self.kv_shared_layer_index is not None and past_key_values is not None: # In this case we need special handling of the slice as the layer is of fixed small size (for full layers, we never go beyond) - layer = past_key_value.layers[self.kv_shared_layer_index] + layer = past_key_values.layers[self.kv_shared_layer_index] # Device of past layer may be different from current one indices = cache_position.to(layer.keys.device) # Sliding window cache layers might have smaller size (for full layers, we never go beyond) @@ -1793,7 +1795,7 @@ class Gemma3nTextAttention(Gemma3Attention): value_states = self.v_norm(value_states) value_states = value_states.transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -1801,7 +1803,7 @@ class Gemma3nTextAttention(Gemma3Attention): "cache_position": cache_position, "sliding_window": self.sliding_window, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1839,6 +1841,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): self.per_layer_projection = nn.Linear(self.hidden_size_per_layer_input, self.hidden_size, bias=False) self.post_per_layer_input_norm = Gemma3nRMSNorm(self.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1847,7 +1850,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): per_layer_input: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1870,7 +1873,7 @@ class Gemma3nTextDecoderLayer(Gemma3DecoderLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -2121,7 +2124,7 @@ class Gemma3nTextModel(Gemma3TextModel): per_layer_input, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/git/modeling_git.py b/src/transformers/models/git/modeling_git.py index 5a195b06ec..53f02443f3 100644 --- a/src/transformers/models/git/modeling_git.py +++ b/src/transformers/models/git/modeling_git.py @@ -43,6 +43,7 @@ from ...utils import ( logging, torch_int, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_git import GitConfig, GitVisionConfig @@ -151,12 +152,13 @@ class GitSelfAttention(nn.Module): self.max_position_embeddings = config.max_position_embeddings self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: @@ -178,9 +180,9 @@ class GitSelfAttention(nn.Module): .view(batch_size, -1, self.num_attention_heads, self.attention_head_size) .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # NOTE: like in other caches, we store the text component. In GIT it means we discard the image component. - key_layer_past, value_layer_past = past_key_value.update( + key_layer_past, value_layer_past = past_key_values.update( key_layer[:, :, cutoff:, :], value_layer[:, :, cutoff:, :], self.layer_idx ) key_layer = torch.cat([key_layer[:, :, :cutoff, :], key_layer_past], dim=2) @@ -191,7 +193,7 @@ class GitSelfAttention(nn.Module): if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -284,12 +286,13 @@ class GitAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: @@ -297,7 +300,7 @@ class GitAttention(nn.Module): hidden_states, attention_mask, head_mask, - past_key_value, + past_key_values, output_attentions, pixel_values_present, ) @@ -345,12 +348,13 @@ class GitLayer(GradientCheckpointingLayer): self.intermediate = GitIntermediate(config) self.output = GitOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, pixel_values_present: Optional[bool] = False, ) -> tuple[torch.Tensor]: @@ -360,7 +364,7 @@ class GitLayer(GradientCheckpointingLayer): attention_mask, head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, pixel_values_present=pixel_values_present, ) diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py index cc2de4f9e5..17895e8283 100644 --- a/src/transformers/models/glm/modeling_glm.py +++ b/src/transformers/models/glm/modeling_glm.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_glm import GlmConfig @@ -172,12 +173,13 @@ class GlmAttention(nn.Module): ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -191,10 +193,10 @@ class GlmAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -284,12 +286,13 @@ class GlmDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GlmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -302,7 +305,7 @@ class GlmDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -403,7 +406,7 @@ class GlmModel(GlmPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py index 76193a6132..c8febbc563 100644 --- a/src/transformers/models/glm4/modeling_glm4.py +++ b/src/transformers/models/glm4/modeling_glm4.py @@ -40,6 +40,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_glm4 import Glm4Config @@ -74,12 +75,13 @@ class Glm4DecoderLayer(GradientCheckpointingLayer): self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -92,7 +94,7 @@ class Glm4DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -221,12 +223,13 @@ class Glm4Attention(nn.Module): ) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -240,10 +243,10 @@ class Glm4Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -407,7 +410,7 @@ class Glm4Model(Glm4PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/glm4/modular_glm4.py b/src/transformers/models/glm4/modular_glm4.py index 4312110293..6bbc9b601f 100644 --- a/src/transformers/models/glm4/modular_glm4.py +++ b/src/transformers/models/glm4/modular_glm4.py @@ -23,6 +23,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import CausalLMOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..glm.modeling_glm import GlmAttention, GlmForCausalLM, GlmForSequenceClassification, GlmForTokenClassification from ..phi3.modeling_phi3 import Phi3MLP from .configuration_glm4 import Glm4Config @@ -50,12 +51,13 @@ class Glm4DecoderLayer(GradientCheckpointingLayer): self.post_self_attn_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_mlp_layernorm = Glm4RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -68,7 +70,7 @@ class Glm4DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index 81e1611795..8cda3a71b5 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -37,6 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_glm4_moe import Glm4MoeConfig @@ -152,12 +153,13 @@ class Glm4MoeAttention(nn.Module): self.q_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = Glm4MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -179,10 +181,10 @@ class Glm4MoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; position_ids needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -359,12 +361,13 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Glm4MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -377,7 +380,7 @@ class Glm4MoeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -520,7 +523,7 @@ class Glm4MoeModel(Glm4MoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/glm4v/modeling_glm4v.py b/src/transformers/models/glm4v/modeling_glm4v.py index c93ffe8bd1..0ee6d25f91 100644 --- a/src/transformers/models/glm4v/modeling_glm4v.py +++ b/src/transformers/models/glm4v/modeling_glm4v.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_glm4v import Glm4vConfig, Glm4vTextConfig, Glm4vVisionConfig @@ -650,13 +651,14 @@ class Glm4vTextAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -677,9 +679,9 @@ class Glm4vTextAttention(nn.Module): query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -698,7 +700,7 @@ class Glm4vTextAttention(nn.Module): attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class Glm4vTextMLP(nn.Module): @@ -730,13 +732,14 @@ class Glm4vTextDecoderLayer(GradientCheckpointingLayer): self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -752,7 +755,7 @@ class Glm4vTextDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -902,7 +905,7 @@ class Glm4vTextModel(Glm4vPreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/glm4v/modular_glm4v.py b/src/transformers/models/glm4v/modular_glm4v.py index df32beb5da..414fac94cb 100644 --- a/src/transformers/models/glm4v/modular_glm4v.py +++ b/src/transformers/models/glm4v/modular_glm4v.py @@ -36,6 +36,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import ImagesKwargs, Unpack from ...tokenization_utils_base import PreTokenizedInput, TextInput from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from ...video_utils import VideoInput from ..glm4.modeling_glm4 import Glm4MLP, Glm4RMSNorm, eager_attention_forward from ..qwen2_5_vl.configuration_qwen2_5_vl import Qwen2_5_VLConfig @@ -730,13 +731,14 @@ class Glm4vTextAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -757,9 +759,9 @@ class Glm4vTextAttention(nn.Module): query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -778,7 +780,7 @@ class Glm4vTextAttention(nn.Module): attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() attn_output = self.o_proj(attn_output) - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class Glm4vTextMLP(Glm4MLP): @@ -796,13 +798,14 @@ class Glm4vTextDecoderLayer(GradientCheckpointingLayer): self.post_self_attn_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_mlp_layernorm = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -818,7 +821,7 @@ class Glm4vTextDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -938,7 +941,7 @@ class Glm4vTextModel(Qwen2_5_VLTextModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 3403ac3196..b08c2f718a 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -47,6 +47,7 @@ from ...utils import ( auto_docstring, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_gpt2 import GPT2Config @@ -266,10 +267,11 @@ class GPT2Attention(nn.Module): return attn_output, attn_weights + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -279,16 +281,16 @@ class GPT2Attention(nn.Module): **kwargs, ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]: is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values if is_cross_attention: if not hasattr(self, "q_attn"): @@ -300,7 +302,7 @@ class GPT2Attention(nn.Module): attention_mask = encoder_attention_mask # Try to get key/value states from cache if possible - if past_key_value is not None and is_updated: + if past_key_values is not None and is_updated: key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values else: @@ -317,8 +319,8 @@ class GPT2Attention(nn.Module): shape_q = (*query_states.shape[:-1], -1, self.head_dim) query_states = query_states.view(shape_q).transpose(1, 2) - if (past_key_value is not None and not is_cross_attention) or ( - past_key_value is not None and is_cross_attention and not is_updated + if (past_key_values is not None and not is_cross_attention) or ( + past_key_values is not None and is_cross_attention and not is_updated ): # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None @@ -327,7 +329,7 @@ class GPT2Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention @@ -393,10 +395,11 @@ class GPT2Block(GradientCheckpointingLayer): self.mlp = GPT2MLP(inner_dim, config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[tuple[torch.FloatTensor]], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -410,7 +413,7 @@ class GPT2Block(GradientCheckpointingLayer): hidden_states = self.ln_1(hidden_states) attn_output, self_attn_weights = self.attn( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, attention_mask=attention_mask, head_mask=head_mask, @@ -432,7 +435,7 @@ class GPT2Block(GradientCheckpointingLayer): hidden_states = self.ln_cross_attn(hidden_states) cross_attn_output, cross_attn_weights = self.crossattention( hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, diff --git a/src/transformers/models/gpt_neox/modeling_gpt_neox.py b/src/transformers/models/gpt_neox/modeling_gpt_neox.py index 908be720fd..fd89c071a1 100755 --- a/src/transformers/models/gpt_neox/modeling_gpt_neox.py +++ b/src/transformers/models/gpt_neox/modeling_gpt_neox.py @@ -27,6 +27,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_gpt_neox import GPTNeoXConfig @@ -321,12 +322,13 @@ class GPTNeoXDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = GPTNeoXRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = GPTNeoXRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -339,7 +341,7 @@ class GPTNeoXDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 0c07800955..81af054833 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -34,6 +34,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_gpt_oss import GptOssConfig @@ -285,12 +286,13 @@ class GptOssAttention(nn.Module): self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -304,9 +306,9 @@ class GptOssAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -340,12 +342,13 @@ class GptOssDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -358,7 +361,7 @@ class GptOssDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -494,7 +497,7 @@ class GptOssModel(GptOssPreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index b49366362b..891ab3a506 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -32,6 +32,7 @@ from ...utils import ( auto_docstring, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from ..llama.modeling_llama import ( LlamaDecoderLayer, @@ -242,12 +243,13 @@ class GptOssAttention(Qwen2Attention): ) self.sinks = nn.Parameter(torch.empty(config.num_attention_heads)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -261,9 +263,9 @@ class GptOssAttention(Qwen2Attention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -297,12 +299,13 @@ class GptOssDecoderLayer(LlamaDecoderLayer): self.post_attention_layernorm = GptOssRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -315,7 +318,7 @@ class GptOssDecoderLayer(LlamaDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -423,7 +426,7 @@ class GptOssModel(MixtralModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py index c614c1bcba..318be2bec1 100644 --- a/src/transformers/models/granite/modeling_granite.py +++ b/src/transformers/models/granite/modeling_granite.py @@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_granite import GraniteConfig @@ -140,12 +141,13 @@ class GraniteAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -159,10 +161,10 @@ class GraniteAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -232,12 +234,13 @@ class GraniteDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = GraniteRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.residual_multiplier = config.residual_multiplier + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -256,7 +259,7 @@ class GraniteDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -275,7 +278,7 @@ class GraniteDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -445,7 +448,7 @@ class GraniteModel(GranitePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granite/modular_granite.py b/src/transformers/models/granite/modular_granite.py index 9dcd2c1d1b..5e447e80aa 100644 --- a/src/transformers/models/granite/modular_granite.py +++ b/src/transformers/models/granite/modular_granite.py @@ -24,6 +24,7 @@ from ...masking_utils import create_causal_mask from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -51,12 +52,13 @@ class GraniteDecoderLayer(LlamaDecoderLayer): self.residual_multiplier = config.residual_multiplier self.self_attn = GraniteAttention(config=config, layer_idx=layer_idx) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -75,7 +77,7 @@ class GraniteDecoderLayer(LlamaDecoderLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -94,7 +96,7 @@ class GraniteDecoderLayer(LlamaDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -200,7 +202,7 @@ class GraniteModel(LlamaModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granitemoe/modeling_granitemoe.py b/src/transformers/models/granitemoe/modeling_granitemoe.py index 13ab0dae0c..012a0581a8 100644 --- a/src/transformers/models/granitemoe/modeling_granitemoe.py +++ b/src/transformers/models/granitemoe/modeling_granitemoe.py @@ -28,6 +28,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, MoeCausalLMOutputWithPa from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_granitemoe import GraniteMoeConfig @@ -422,12 +423,13 @@ class GraniteMoeAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings @@ -447,10 +449,10 @@ class GraniteMoeAttention(nn.Module): if position_embeddings is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -513,12 +515,13 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer): self.residual_multiplier = config.residual_multiplier + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -538,7 +541,7 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence output_router_logits (`bool`, *optional*): @@ -560,7 +563,7 @@ class GraniteMoeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -710,7 +713,7 @@ class GraniteMoeModel(GraniteMoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py index 109c60eab9..91439bb2a3 100644 --- a/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modeling_granitemoehybrid.py @@ -36,6 +36,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available from .configuration_granitemoehybrid import GraniteMoeHybridConfig @@ -171,12 +172,13 @@ class GraniteMoeHybridAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings @@ -196,10 +198,10 @@ class GraniteMoeHybridAttention(nn.Module): if position_embeddings is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1141,11 +1143,12 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): # Accept 0 experts: skip MoE if num_local_experts == 0 self.has_experts = getattr(config, "num_local_experts", 0) > 0 + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1159,7 +1162,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1185,7 +1188,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): hidden_states = self.mamba( hidden_states=hidden_states, cache_position=cache_position, - cache_params=past_key_value, + cache_params=past_key_values, attention_mask=attention_mask, **kwargs, ) @@ -1195,7 +1198,7 @@ class GraniteMoeHybridDecoderLayer(GradientCheckpointingLayer): hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1398,7 +1401,7 @@ class GraniteMoeHybridModel(GraniteMoeHybridPreTrainedModel): layer_outputs = decoder_layer( hidden_states, attention_mask=layer_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py index 242c95b907..25151b6936 100644 --- a/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py +++ b/src/transformers/models/granitemoehybrid/modular_granitemoehybrid.py @@ -22,6 +22,7 @@ from ...cache_utils import Cache from ...modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..bamba.configuration_bamba import BambaConfig from ..bamba.modeling_bamba import BambaMixer, BambaRMSNormGated, HybridMambaAttentionDynamicCache from ..granitemoeshared.modeling_granitemoeshared import ( @@ -76,11 +77,12 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer): # Accept 0 experts: skip MoE if num_local_experts == 0 self.has_experts = getattr(config, "num_local_experts", 0) > 0 + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -94,7 +96,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer): attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, query_sequence_length, key_sequence_length)` if default attention is used. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -120,7 +122,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer): hidden_states = self.mamba( hidden_states=hidden_states, cache_position=cache_position, - cache_params=past_key_value, + cache_params=past_key_values, attention_mask=attention_mask, **kwargs, ) @@ -130,7 +132,7 @@ class GraniteMoeHybridDecoderLayer(GraniteMoeSharedDecoderLayer): hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -267,7 +269,7 @@ class GraniteMoeHybridModel(GraniteMoeSharedModel): layer_outputs = decoder_layer( hidden_states, attention_mask=layer_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py index 1acac400cc..fa7577735c 100644 --- a/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modeling_granitemoeshared.py @@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_granitemoeshared import GraniteMoeSharedConfig @@ -381,12 +382,13 @@ class GraniteMoeSharedAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=config.attention_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # None or rope embeddings @@ -406,10 +408,10 @@ class GraniteMoeSharedAttention(nn.Module): if position_embeddings is not None: query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -446,12 +448,13 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): self.residual_multiplier = config.residual_multiplier self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -471,7 +474,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence output_router_logits (`bool`, *optional*): @@ -493,7 +496,7 @@ class GraniteMoeSharedDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -684,7 +687,7 @@ class GraniteMoeSharedModel(GraniteMoeSharedPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py index 630e5aa184..4170deca2e 100644 --- a/src/transformers/models/granitemoeshared/modular_granitemoeshared.py +++ b/src/transformers/models/granitemoeshared/modular_granitemoeshared.py @@ -22,6 +22,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..granitemoe.modeling_granitemoe import ( GraniteMoeDecoderLayer, GraniteMoeForCausalLM, @@ -90,12 +91,13 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer): super().__init__(config, layer_idx) self.shared_mlp = None if config.shared_intermediate_size == 0 else GraniteMoeSharedMLP(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -115,7 +117,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence output_router_logits (`bool`, *optional*): @@ -137,7 +139,7 @@ class GraniteMoeSharedDecoderLayer(GraniteMoeDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py index 44d687d46d..4cca435c10 100644 --- a/src/transformers/models/helium/modeling_helium.py +++ b/src/transformers/models/helium/modeling_helium.py @@ -39,6 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_helium import HeliumConfig @@ -214,12 +215,13 @@ class HeliumAttention(nn.Module): ) self.o_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -233,10 +235,10 @@ class HeliumAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -269,12 +271,13 @@ class HeliumDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = HeliumRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -287,7 +290,7 @@ class HeliumDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -388,7 +391,7 @@ class HeliumModel(HeliumPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/idefics/modeling_idefics.py b/src/transformers/models/idefics/modeling_idefics.py index e47ddbe273..4a678a97fa 100644 --- a/src/transformers/models/idefics/modeling_idefics.py +++ b/src/transformers/models/idefics/modeling_idefics.py @@ -37,6 +37,7 @@ from ...modeling_outputs import ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PretrainedConfig, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_idefics import IdeficsConfig from .perceiver import IdeficsPerceiverResampler from .vision import IdeficsVisionEmbeddings, IdeficsVisionTransformer @@ -573,13 +574,14 @@ class IdeficsAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -602,7 +604,7 @@ class IdeficsAttention(nn.Module): ) kv_seq_len = key_states.shape[-2] - if past_key_value is not None: + if past_key_values is not None: kv_seq_len += cache_position[0] if not is_cross_attention: @@ -610,10 +612,10 @@ class IdeficsAttention(nn.Module): query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) # [bsz, nh, t, hd] - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) if self.qk_layer_norms: query_states = self.q_layer_norm(query_states) @@ -671,12 +673,13 @@ class IdeficsDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = IdeficsRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.dropout = config.dropout + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -693,7 +696,7 @@ class IdeficsDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = hidden_states @@ -705,7 +708,7 @@ class IdeficsDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -796,6 +799,7 @@ class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): if not (hasattr(self, "alpha_cross_attn") and hasattr(self, "alpha_dense")): raise ValueError("Alpha parameters not initialized correctly!") + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -805,7 +809,7 @@ class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): cross_attention_gate: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, **kwargs, ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -823,7 +827,7 @@ class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ if image_hidden_states is None: raise ValueError( @@ -836,7 +840,7 @@ class IdeficsGatedCrossAttentionLayer(GradientCheckpointingLayer): "`cross_attention_gate` is required for Idefics cross attention module to zero-out the cross-attention hidden_states attending to no images." ) - if past_key_value is not None: + if past_key_values is not None: raise NotImplementedError("Past key value states are not implemented for Idefics cross attention module.") residual = hidden_states @@ -1158,7 +1162,7 @@ class IdeficsModel(IdeficsPreTrainedModel): cross_attention_gate=cross_attention_gate, output_attentions=output_attentions, use_cache=use_cache, - past_key_value=None, # not implemented + past_key_values=None, # not implemented **kwargs, ) hidden_states = outputs[0] @@ -1167,7 +1171,7 @@ class IdeficsModel(IdeficsPreTrainedModel): hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/idefics2/modeling_idefics2.py b/src/transformers/models/idefics2/modeling_idefics2.py index 97ad4c7116..18d7702a3d 100644 --- a/src/transformers/models/idefics2/modeling_idefics2.py +++ b/src/transformers/models/idefics2/modeling_idefics2.py @@ -31,6 +31,7 @@ from ...modeling_outputs import BaseModelOutput, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_idefics2 import Idefics2Config, Idefics2PerceiverConfig, Idefics2VisionConfig @@ -634,13 +635,14 @@ class Idefics2PerceiverAttention(nn.Module): self.is_causal = False + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, latents: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -652,9 +654,9 @@ class Idefics2PerceiverAttention(nn.Module): context (`torch.Tensor`): Tensor of shape [bsz, seq, embed_dim] representing long-form context to resample. attention_mask (`torch.Tensor`, *optional*): Tensor of shape [bsz, 1, seq, n_latents] representing attention mask. position_ids (`torch.LongTensor`, *optional*): Tensor of shape [bsz, seq] representing position indices of each input token. - past_key_value (`tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states. + past_key_values (`tuple[torch.Tensor]`, *optional*): Tuple of tensors containing cached key and value states. output_attentions (`bool`, *optional*, defaults to `False`): Whether to return attention weights. - use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_value for caching. + use_cache (`bool`, *optional*, defaults to `False`): Whether to use past_key_values for caching. """ bsz, q_len, _ = latents.size() kv_seq_len = q_len + context.size()[1] @@ -669,10 +671,10 @@ class Idefics2PerceiverAttention(nn.Module): keys = keys.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) values = values.view(bsz, kv_seq_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - past_key_value = getattr(self, "past_key_value", past_key_value) + past_key_values = getattr(self, "past_key_values", past_key_values) - if past_key_value is not None: - keys, values = past_key_value.update(keys, values, self.layer_idx) + if past_key_values is not None: + keys, values = past_key_values.update(keys, values, self.layer_idx) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -701,7 +703,7 @@ class Idefics2PerceiverAttention(nn.Module): if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values class Idefics2PerceiverLayer(nn.Module): @@ -723,13 +725,14 @@ class Idefics2PerceiverLayer(nn.Module): hidden_act=config.hidden_act, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, latents: torch.Tensor, context: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs, @@ -746,7 +749,7 @@ class Idefics2PerceiverLayer(nn.Module): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ residual = latents @@ -834,7 +837,7 @@ class Idefics2PerceiverResampler(Idefics2PreTrainedModel): context, attention_mask=attention_mask, position_ids=None, - past_key_value=None, + past_key_values=None, output_attentions=False, use_cache=False, ) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 223df2e819..8ac7e905a2 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -46,6 +46,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_informer import InformerConfig @@ -433,11 +434,12 @@ class InformerAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -462,19 +464,19 @@ class InformerAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -484,7 +486,7 @@ class InformerAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -492,7 +494,7 @@ class InformerAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -556,11 +558,12 @@ class InformerProbSparseAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -579,19 +582,19 @@ class InformerProbSparseAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -601,7 +604,7 @@ class InformerProbSparseAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -609,7 +612,7 @@ class InformerProbSparseAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -879,6 +882,7 @@ class InformerDecoderLayer(GradientCheckpointingLayer): layer_idx=layer_idx, ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -887,7 +891,7 @@ class InformerDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -905,7 +909,7 @@ class InformerDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -918,7 +922,7 @@ class InformerDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -938,7 +942,7 @@ class InformerDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1283,7 +1287,7 @@ class InformerDecoder(InformerPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index f1f73f505f..f8fb02c1c5 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -37,6 +37,7 @@ from ...utils import ( auto_docstring, is_torch_flex_attn_available, ) +from ...utils.deprecation import deprecate_kwarg from ..bart.modeling_bart import BartAttention from ..time_series_transformer.modeling_time_series_transformer import ( TimeSeriesFeatureEmbedder, @@ -245,11 +246,12 @@ class InformerProbSparseAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -268,19 +270,19 @@ class InformerProbSparseAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -290,7 +292,7 @@ class InformerProbSparseAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -298,7 +300,7 @@ class InformerProbSparseAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 6db0a49dc6..4fe7d6cee1 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -40,6 +40,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPas from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_jamba import JambaConfig @@ -317,12 +318,13 @@ class JambaAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -337,8 +339,8 @@ class JambaAttention(nn.Module): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -369,7 +371,7 @@ class JambaAttention(nn.Module): if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values # Adapted from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Jamba @@ -393,7 +395,7 @@ class JambaFlashAttention2(JambaAttention): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -412,8 +414,8 @@ class JambaFlashAttention2(JambaAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -470,7 +472,7 @@ class JambaFlashAttention2(JambaAttention): if not output_attentions: attn_weights = None - return attn_output, attn_weights, past_key_value + return attn_output, attn_weights, past_key_values # Adapted from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Jamba @@ -482,12 +484,13 @@ class JambaSdpaAttention(JambaAttention): """ # Adapted from JambaAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -502,7 +505,7 @@ class JambaSdpaAttention(JambaAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, ) @@ -517,8 +520,8 @@ class JambaSdpaAttention(JambaAttention): key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -553,7 +556,7 @@ class JambaSdpaAttention(JambaAttention): attn_output = self.o_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, past_key_values JAMBA_ATTENTION_CLASSES = { @@ -923,12 +926,13 @@ class JambaAttentionDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -939,7 +943,7 @@ class JambaAttentionDecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -961,7 +965,7 @@ class JambaAttentionDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1005,12 +1009,13 @@ class JambaMambaDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = JambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[HybridMambaAttentionDynamicCache] = None, + past_key_values: Optional[HybridMambaAttentionDynamicCache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -1021,7 +1026,7 @@ class JambaMambaDecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`HybridMambaAttentionDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1041,7 +1046,7 @@ class JambaMambaDecoderLayer(GradientCheckpointingLayer): hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, attention_mask=attention_mask, ) self_attn_weights = None @@ -1065,7 +1070,7 @@ class JambaMambaDecoderLayer(GradientCheckpointingLayer): outputs += (self_attn_weights,) if use_cache: - outputs += (past_key_value,) + outputs += (past_key_values,) if output_router_logits: outputs += (router_logits,) @@ -1203,7 +1208,7 @@ class JambaModel(JambaPreTrainedModel): hidden_states, attention_mask=layer_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/jetmoe/modeling_jetmoe.py b/src/transformers/models/jetmoe/modeling_jetmoe.py index 37095197bc..2159761547 100644 --- a/src/transformers/models/jetmoe/modeling_jetmoe.py +++ b/src/transformers/models/jetmoe/modeling_jetmoe.py @@ -35,6 +35,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPas from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_jetmoe import JetMoeConfig @@ -498,12 +499,13 @@ class JetMoeAttention(nn.Module): self.rotary_emb = JetMoeRotaryEmbedding(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -520,10 +522,10 @@ class JetMoeAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads for top-k attention experts key_states = key_states.repeat(1, self.top_k, 1, 1) @@ -566,12 +568,13 @@ class JetMoeSdpaAttention(JetMoeAttention): """ # Adapted from JetMoeAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -586,7 +589,7 @@ class JetMoeSdpaAttention(JetMoeAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -604,10 +607,10 @@ class JetMoeSdpaAttention(JetMoeAttention): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads for top-k attention experts key_states = key_states.repeat(1, self.top_k, 1, 1) @@ -655,12 +658,13 @@ class JetMoeFlashAttention2(JetMoeAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[torch.FloatTensor], attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -696,10 +700,10 @@ class JetMoeFlashAttention2(JetMoeAttention): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads for top-k attention experts key_states = key_states.repeat(1, self.top_k, 1, 1) @@ -789,11 +793,12 @@ class JetMoeBlock(GradientCheckpointingLayer): self.mlp = JetMoeMoE(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: Optional[torch.FloatTensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.FloatTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -805,7 +810,7 @@ class JetMoeBlock(GradientCheckpointingLayer): hidden_states=self.input_layernorm(hidden_states), attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -961,7 +966,7 @@ class JetMoeModel(JetMoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/kosmos2/modeling_kosmos2.py b/src/transformers/models/kosmos2/modeling_kosmos2.py index 57419c150c..92aa3ace5a 100644 --- a/src/transformers/models/kosmos2/modeling_kosmos2.py +++ b/src/transformers/models/kosmos2/modeling_kosmos2.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_int +from ...utils.deprecation import deprecate_kwarg from .configuration_kosmos2 import Kosmos2Config, Kosmos2TextConfig, Kosmos2VisionConfig @@ -709,11 +710,12 @@ class KosmosTextAttention(nn.Module): if add_inner_attn_layernorm: self.inner_attn_ln = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -730,19 +732,19 @@ class KosmosTextAttention(nn.Module): query_states = self.q_proj(hidden_states) query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -752,7 +754,7 @@ class KosmosTextAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -760,7 +762,7 @@ class KosmosTextAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward @@ -848,6 +850,7 @@ class Kosmos2TextBlock(GradientCheckpointingLayer): self.ffn = Kosmos2TextFFN(config) self.final_layer_norm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -856,7 +859,7 @@ class Kosmos2TextBlock(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -867,7 +870,7 @@ class Kosmos2TextBlock(GradientCheckpointingLayer): hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -894,7 +897,7 @@ class Kosmos2TextBlock(GradientCheckpointingLayer): encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, **kwargs, @@ -1114,7 +1117,7 @@ class Kosmos2TextTransformer(nn.Module): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1495,7 +1498,7 @@ class Kosmos2ImageToTextProjection(nn.Module): hidden_states, attn_weights = self.x_attn( hidden_states=latent_query, encoder_hidden_states=key_value_states, - past_key_value=None, + past_key_values=None, attention_mask=None, output_attentions=None, ) diff --git a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py index 304f7261de..89f269f8a0 100644 --- a/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +++ b/src/transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py @@ -37,6 +37,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ..auto import AutoModel from .configuration_kyutai_speech_to_text import KyutaiSpeechToTextConfig @@ -428,14 +429,13 @@ class KyutaiSpeechToTextAttention(nn.Module): self.rope_theta = config.rope_theta self.rotary_emb = KyutaiSpeechToTextRotaryEmbedding(config) - # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward - # no longer copied after attention refactors + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -454,14 +454,14 @@ class KyutaiSpeechToTextAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -511,17 +511,18 @@ class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -546,14 +547,14 @@ class KyutaiSpeechToTextFlashAttention2(KyutaiSpeechToTextAttention): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -626,12 +627,13 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention): """ # Adapted from KyutaiSpeechToTextAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -647,7 +649,7 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -667,14 +669,14 @@ class KyutaiSpeechToTextSdpaAttention(KyutaiSpeechToTextAttention): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -735,12 +737,13 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer): self._attn_implementation = config._attn_implementation + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -758,7 +761,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): @@ -774,7 +777,7 @@ class KyutaiSpeechToTextDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -884,7 +887,7 @@ class KyutaiSpeechToTextModel(KyutaiSpeechToTextPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/led/modeling_led.py b/src/transformers/models/led/modeling_led.py index 639fa3df4d..9ca56d8931 100755 --- a/src/transformers/models/led/modeling_led.py +++ b/src/transformers/models/led/modeling_led.py @@ -32,6 +32,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_led import LEDConfig @@ -790,11 +791,12 @@ class LEDDecoderAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -810,19 +812,19 @@ class LEDDecoderAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -832,7 +834,7 @@ class LEDDecoderAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -840,7 +842,7 @@ class LEDDecoderAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -903,7 +905,7 @@ class LEDDecoderAttention(nn.Module): attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, past_key_values class LEDEncoderLayer(GradientCheckpointingLayer): @@ -997,6 +999,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1005,7 +1008,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -1023,7 +1026,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): *(decoder_attention_heads,)*. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for encoder attention heads in a given layer of size *(decoder_attention_heads,)*. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`): Whether the base model outputs attentions. This requires the attentions tensor to be reshaped in this function. """ @@ -1032,7 +1035,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): # Self-Attention hidden_states, self_attn_weights, present_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -1053,7 +1056,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1076,7 +1079,7 @@ class LEDDecoderLayer(GradientCheckpointingLayer): outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (past_key_value,) + outputs += (past_key_values,) return outputs @@ -1826,7 +1829,7 @@ class LEDDecoder(LEDPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1980,7 +1983,7 @@ class LEDModel(LEDPreTrainedModel): global_attentions=encoder_outputs[3] if len(encoder_outputs) > 3 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 5e8e6eaeb6..7fc244cb58 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -33,6 +33,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from ...utils.import_utils import is_causal_conv1d_available from .configuration_lfm2 import Lfm2Config @@ -359,12 +360,13 @@ class Lfm2Attention(nn.Module): self.q_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) self.k_layernorm = Lfm2RMSNorm(self.head_dim, eps=config.norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -378,9 +380,9 @@ class Lfm2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -439,10 +441,11 @@ class Lfm2ShortConv(nn.Module): self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def cuda_kernels_forward( self, x: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -453,19 +456,19 @@ class Lfm2ShortConv(nn.Module): Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_value is not None and cache_position[0] > 0: + if past_key_values is not None and cache_position[0] > 0: conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_value.conv_cache[self.layer_idx], + past_key_values.conv_cache[self.layer_idx], conv_weights, self.conv.bias, None, ) conv_out = conv_out.unsqueeze(-1) else: - if past_key_value is not None: + if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -473,10 +476,11 @@ class Lfm2ShortConv(nn.Module): y = self.out_proj(y.transpose(-1, -2).contiguous()) return y + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def slow_forward( self, x: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -488,21 +492,21 @@ class Lfm2ShortConv(nn.Module): Bx = B * x - if past_key_value is not None and cache_position[0] > 0: - conv_state = past_key_value.conv_cache[self.layer_idx] + if past_key_values is not None and cache_position[0] > 0: + conv_state = past_key_values.conv_cache[self.layer_idx] cache_position = cache_position.clamp(0, self.L_cache - 1) conv_state = conv_state.roll(shifts=-1, dims=-1) conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias conv_out = conv_out.unsqueeze(-1) else: - if past_key_value is not None: + if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = self.conv(Bx)[..., :seqlen] @@ -511,16 +515,17 @@ class Lfm2ShortConv(nn.Module): y = self.out_proj(y) return y + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling(): - return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask) - return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask) + return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask) + return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask) class Lfm2DecoderLayer(GradientCheckpointingLayer): @@ -536,13 +541,14 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> torch.Tensor: @@ -553,14 +559,14 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) else: hidden_states = self.conv( hidden_states=self.operator_norm(hidden_states), - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, attention_mask=attention_mask, ) @@ -659,7 +665,7 @@ class Lfm2Model(Lfm2PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index d91b82926a..046d79dbdd 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -24,6 +24,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available from ..bamba.modeling_bamba import apply_mask_to_padding_states from ..llama.modeling_llama import ( @@ -240,12 +241,13 @@ class Lfm2Attention(LlamaAttention): del self.o_proj del self.attention_dropout + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -259,9 +261,9 @@ class Lfm2Attention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -305,10 +307,11 @@ class Lfm2ShortConv(nn.Module): self.in_proj = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=self.bias) self.out_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=self.bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def cuda_kernels_forward( self, x: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -319,19 +322,19 @@ class Lfm2ShortConv(nn.Module): Bx = B * x conv_weights = self.conv.weight.view(self.conv.weight.size(0), self.conv.weight.size(2)) - if past_key_value is not None and cache_position[0] > 0: + if past_key_values is not None and cache_position[0] > 0: conv_out = causal_conv1d_update( Bx.squeeze(-1), - past_key_value.conv_cache[self.layer_idx], + past_key_values.conv_cache[self.layer_idx], conv_weights, self.conv.bias, None, ) conv_out = conv_out.unsqueeze(-1) else: - if past_key_value is not None: + if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = causal_conv1d_fn(Bx, conv_weights, self.conv.bias, activation=None) @@ -339,10 +342,11 @@ class Lfm2ShortConv(nn.Module): y = self.out_proj(y.transpose(-1, -2).contiguous()) return y + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def slow_forward( self, x: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): @@ -354,21 +358,21 @@ class Lfm2ShortConv(nn.Module): Bx = B * x - if past_key_value is not None and cache_position[0] > 0: - conv_state = past_key_value.conv_cache[self.layer_idx] + if past_key_values is not None and cache_position[0] > 0: + conv_state = past_key_values.conv_cache[self.layer_idx] cache_position = cache_position.clamp(0, self.L_cache - 1) conv_state = conv_state.roll(shifts=-1, dims=-1) conv_state[:, :, cache_position] = Bx.to(device=conv_state.device, dtype=conv_state.dtype) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = torch.sum(conv_state.to(Bx.device) * self.conv.weight[:, 0, :], dim=-1) if self.bias: conv_out += self.conv.bias conv_out = conv_out.unsqueeze(-1) else: - if past_key_value is not None: + if past_key_values is not None: conv_state = nn.functional.pad(Bx, (self.L_cache - Bx.shape[-1], 0)) - past_key_value.conv_cache[self.layer_idx].copy_(conv_state) + past_key_values.conv_cache[self.layer_idx].copy_(conv_state) conv_out = self.conv(Bx)[..., :seqlen] @@ -377,16 +381,17 @@ class Lfm2ShortConv(nn.Module): y = self.out_proj(y) return y + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, - past_key_value: Optional[Lfm2HybridConvCache] = None, + past_key_values: Optional[Lfm2HybridConvCache] = None, cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, ): if is_fast_path_available and "cuda" in hidden_states.device.type and not torch._dynamo.is_compiling(): - return self.cuda_kernels_forward(hidden_states, past_key_value, cache_position, attention_mask) - return self.slow_forward(hidden_states, past_key_value, cache_position, attention_mask) + return self.cuda_kernels_forward(hidden_states, past_key_values, cache_position, attention_mask) + return self.slow_forward(hidden_states, past_key_values, cache_position, attention_mask) class Lfm2DecoderLayer(GradientCheckpointingLayer): @@ -402,13 +407,14 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): self.operator_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) self.ffn_norm = Lfm2RMSNorm(config.hidden_size, eps=config.norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> torch.Tensor: @@ -419,14 +425,14 @@ class Lfm2DecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) else: hidden_states = self.conv( hidden_states=self.operator_norm(hidden_states), - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, attention_mask=attention_mask, ) @@ -498,7 +504,7 @@ class Lfm2Model(LlamaModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/lightglue/modeling_lightglue.py b/src/transformers/models/lightglue/modeling_lightglue.py index 744c397c60..ee3b39ed9e 100644 --- a/src/transformers/models/lightglue/modeling_lightglue.py +++ b/src/transformers/models/lightglue/modeling_lightglue.py @@ -30,6 +30,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import ModelOutput, TransformersKwargs, auto_docstring +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import can_return_tuple from ..auto.modeling_auto import AutoModelForKeypointDetection from .configuration_lightglue import LightGlueConfig @@ -199,6 +200,7 @@ class LightGlueAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 71753d1708..06fbeba496 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_llama import LlamaConfig @@ -219,12 +220,13 @@ class LlamaAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -238,10 +240,10 @@ class LlamaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -274,12 +276,13 @@ class LlamaDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -292,7 +295,7 @@ class LlamaDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -393,7 +396,7 @@ class LlamaModel(LlamaPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/llama4/modeling_llama4.py b/src/transformers/models/llama4/modeling_llama4.py index 62a942ef3f..2aeb414d19 100644 --- a/src/transformers/models/llama4/modeling_llama4.py +++ b/src/transformers/models/llama4/modeling_llama4.py @@ -35,6 +35,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_llama4 import Llama4Config, Llama4TextConfig @@ -308,12 +309,13 @@ class Llama4TextAttention(nn.Module): if self.config.use_qk_norm and self.use_rope: self.qk_norm = Llama4TextL2Norm(config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -344,10 +346,10 @@ class Llama4TextAttention(nn.Module): query_states = query_states.transpose(1, 2) key_states = key_states.transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -384,12 +386,13 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Llama4TextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -404,7 +407,7 @@ class Llama4TextDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -545,7 +548,7 @@ class Llama4TextModel(Llama4PreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=freq_cis, @@ -783,7 +786,7 @@ class Llama4VisionAttention(nn.Module): hidden_states: torch.Tensor, freqs_ci: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] diff --git a/src/transformers/models/longt5/modeling_longt5.py b/src/transformers/models/longt5/modeling_longt5.py index 87badf1ad9..1d89feab05 100644 --- a/src/transformers/models/longt5/modeling_longt5.py +++ b/src/transformers/models/longt5/modeling_longt5.py @@ -45,6 +45,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_longt5 import LongT5Config @@ -441,13 +442,14 @@ class LongT5Attention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -468,18 +470,18 @@ class LongT5Attention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -489,7 +491,7 @@ class LongT5Attention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -497,7 +499,7 @@ class LongT5Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -1018,13 +1020,14 @@ class LongT5LayerSelfAttention(nn.Module): self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -1035,7 +1038,7 @@ class LongT5LayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -1061,7 +1064,7 @@ class LongT5LayerLocalSelfAttention(nn.Module): position_bias=None, layer_head_mask=None, output_attentions=False, - **kwargs: Any, # to accept past_key_value and use_cache kwargs + **kwargs: Any, # to accept past_key_values and use_cache kwargs ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.LocalSelfAttention( @@ -1094,7 +1097,7 @@ class LongT5LayerTransientGlobalSelfAttention(nn.Module): position_bias=None, layer_head_mask=None, output_attentions=False, - **kwargs: Any, # to accept past_key_value and use_cache kwargs + **kwargs: Any, # to accept past_key_values and use_cache kwargs ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.TransientGlobalSelfAttention( @@ -1117,6 +1120,7 @@ class LongT5LayerCrossAttention(nn.Module): self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -1124,7 +1128,7 @@ class LongT5LayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -1137,7 +1141,7 @@ class LongT5LayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -1172,6 +1176,7 @@ class LongT5Block(GradientCheckpointingLayer): self.layer.append(LongT5LayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -1182,7 +1187,7 @@ class LongT5Block(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -1193,7 +1198,7 @@ class LongT5Block(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -1214,7 +1219,7 @@ class LongT5Block(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1497,7 +1502,7 @@ class LongT5Stack(LongT5PreTrainedModel): encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, return_dict=return_dict, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 89344e6ac1..fadba94f9e 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -44,6 +44,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_m2m_100 import M2M100Config @@ -251,11 +252,12 @@ class M2M100Attention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -280,19 +282,19 @@ class M2M100Attention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -302,7 +304,7 @@ class M2M100Attention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -310,7 +312,7 @@ class M2M100Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -432,6 +434,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -440,7 +443,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -458,7 +461,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -472,7 +475,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -492,7 +495,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -511,6 +514,7 @@ class M2M100DecoderLayer(GradientCheckpointingLayer): if output_attentions: outputs += (self_attn_weights, cross_attn_weights) + return outputs @@ -1133,7 +1137,7 @@ class M2M100Decoder(M2M100PreTrainedModel): cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1275,7 +1279,7 @@ class M2M100Model(M2M100PreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 80ad444075..20dc02213d 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -49,6 +49,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_marian import MarianConfig @@ -186,11 +187,12 @@ class MarianAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -215,19 +217,19 @@ class MarianAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -237,7 +239,7 @@ class MarianAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -245,7 +247,7 @@ class MarianAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -375,6 +377,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -383,7 +386,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -401,7 +404,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -414,7 +417,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -434,7 +437,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -455,6 +458,7 @@ class MarianDecoderLayer(GradientCheckpointingLayer): if output_attentions: outputs += (self_attn_weights, cross_attn_weights) + return outputs @@ -1069,7 +1073,7 @@ class MarianDecoder(MarianPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1276,7 +1280,7 @@ class MarianModel(MarianPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 2f6b5c20ef..11d8ca2d26 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -51,6 +51,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_mbart import MBartConfig @@ -195,11 +196,12 @@ class MBartAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -224,19 +226,19 @@ class MBartAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -246,7 +248,7 @@ class MBartAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -254,7 +256,7 @@ class MBartAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -374,6 +376,7 @@ class MBartDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -382,7 +385,7 @@ class MBartDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -400,7 +403,7 @@ class MBartDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -414,7 +417,7 @@ class MBartDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -434,7 +437,7 @@ class MBartDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1113,7 +1116,7 @@ class MBartDecoder(MBartPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1262,7 +1265,7 @@ class MBartModel(MBartPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/megatron_bert/modeling_megatron_bert.py b/src/transformers/models/megatron_bert/modeling_megatron_bert.py index e358499c60..b12e97e68c 100755 --- a/src/transformers/models/megatron_bert/modeling_megatron_bert.py +++ b/src/transformers/models/megatron_bert/modeling_megatron_bert.py @@ -44,6 +44,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_megatron_bert import MegatronBertConfig @@ -206,13 +207,14 @@ class MegatronBertSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -223,19 +225,19 @@ class MegatronBertSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -249,7 +251,7 @@ class MegatronBertSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -257,14 +259,14 @@ class MegatronBertSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -349,13 +351,14 @@ class MegatronBertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -365,7 +368,7 @@ class MegatronBertAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -420,6 +423,7 @@ class MegatronBertLayer(GradientCheckpointingLayer): self.intermediate = MegatronBertIntermediate(config) self.output = MegatronBertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -427,7 +431,7 @@ class MegatronBertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -437,7 +441,7 @@ class MegatronBertLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -455,7 +459,7 @@ class MegatronBertLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/mimi/modeling_mimi.py b/src/transformers/models/mimi/modeling_mimi.py index b5e52b13f0..0cabf4849c 100644 --- a/src/transformers/models/mimi/modeling_mimi.py +++ b/src/transformers/models/mimi/modeling_mimi.py @@ -31,6 +31,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_mimi import MimiConfig @@ -643,12 +644,13 @@ class MimiAttention(nn.Module): self.rotary_emb = MimiRotaryEmbedding(config) self.sliding_window = config.sliding_window # Ignore copy + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -666,10 +668,10 @@ class MimiAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -719,17 +721,18 @@ class MimiFlashAttention2(MimiAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -753,10 +756,10 @@ class MimiFlashAttention2(MimiAttention): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -829,12 +832,13 @@ class MimiSdpaAttention(MimiAttention): """ # Adapted from MimiAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -850,7 +854,7 @@ class MimiSdpaAttention(MimiAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -869,10 +873,10 @@ class MimiSdpaAttention(MimiAttention): cos, sin = self.rotary_emb(value_states, position_ids) query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -929,12 +933,13 @@ class MimiTransformerLayer(GradientCheckpointingLayer): self.self_attn_layer_scale = MimiLayerScale(config) self.mlp_layer_scale = MimiLayerScale(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -952,7 +957,7 @@ class MimiTransformerLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): @@ -968,7 +973,7 @@ class MimiTransformerLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1130,7 +1135,7 @@ class MimiTransformerModel(nn.Module): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 8a4ae7f356..cfd1af9cec 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -45,6 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from .configuration_minimax import MiniMaxConfig @@ -164,12 +165,13 @@ class MiniMaxLightningAttention(nn.Module): return query_decay, key_decay, diagonal_decay + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -187,8 +189,8 @@ class MiniMaxLightningAttention(nn.Module): # calculated (K.T @ V) and saved as cache attn_weights_inter = None - if past_key_value is not None: - attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx) + if past_key_values is not None: + attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx) if attn_weights_inter is None: attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to( @@ -257,8 +259,8 @@ class MiniMaxLightningAttention(nn.Module): attn_output = self.out_proj(attn_output) # update cache - if past_key_value is not None: - past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter) + if past_key_values is not None: + past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter) return attn_output, attn_weights_inter @@ -352,12 +354,13 @@ class MiniMaxAttention(nn.Module): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -371,10 +374,10 @@ class MiniMaxAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -507,13 +510,14 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer): self.attn_alpha_factor = config.full_attn_alpha_factor self.attn_beta_factor = config.full_attn_beta_factor + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -528,7 +532,7 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer): with `head_dim` being the embedding dimension of each attention head. attention_mask (`torch.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -554,7 +558,7 @@ class MiniMaxDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -706,7 +710,7 @@ class MiniMaxModel(MiniMaxPreTrainedModel): position_embeddings=position_embeddings, attention_mask=input_attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/minimax/modular_minimax.py b/src/transformers/models/minimax/modular_minimax.py index 99be8f7fb5..9582e5f392 100644 --- a/src/transformers/models/minimax/modular_minimax.py +++ b/src/transformers/models/minimax/modular_minimax.py @@ -30,6 +30,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from ..mixtral.configuration_mixtral import MixtralConfig from ..mixtral.modeling_mixtral import ( @@ -278,12 +279,13 @@ class MiniMaxLightningAttention(nn.Module): return query_decay, key_decay, diagonal_decay + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -301,8 +303,8 @@ class MiniMaxLightningAttention(nn.Module): # calculated (K.T @ V) and saved as cache attn_weights_inter = None - if past_key_value is not None: - attn_weights_inter = past_key_value.get_linear_cache(self.layer_idx) + if past_key_values is not None: + attn_weights_inter = past_key_values.get_linear_cache(self.layer_idx) if attn_weights_inter is None: attn_weights_inter = torch.zeros(batch_size, self.num_attention_heads, self.head_dim, self.head_dim).to( @@ -371,8 +373,8 @@ class MiniMaxLightningAttention(nn.Module): attn_output = self.out_proj(attn_output) # update cache - if past_key_value is not None: - past_key_value.set_linear_cache(self.layer_idx, attn_weights_inter) + if past_key_values is not None: + past_key_values.set_linear_cache(self.layer_idx, attn_weights_inter) return attn_output, attn_weights_inter @@ -403,13 +405,14 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer): self.attn_alpha_factor = config.full_attn_alpha_factor self.attn_beta_factor = config.full_attn_beta_factor + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -424,7 +427,7 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer): with `head_dim` being the embedding dimension of each attention head. attention_mask (`torch.Tensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -450,7 +453,7 @@ class MiniMaxDecoderLayer(MixtralDecoderLayer, GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -537,7 +540,7 @@ class MiniMaxModel(MixtralModel): position_embeddings=position_embeddings, attention_mask=input_attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index 212151206a..7d712eeaa2 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -28,6 +28,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from .configuration_mistral import MistralConfig @@ -136,12 +137,13 @@ class MistralAttention(nn.Module): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -155,10 +157,10 @@ class MistralAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -211,12 +213,13 @@ class MistralDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -229,7 +232,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -367,7 +370,7 @@ class MistralModel(MistralPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/mistral/modular_mistral.py b/src/transformers/models/mistral/modular_mistral.py index ece58e6e78..94aacded35 100644 --- a/src/transformers/models/mistral/modular_mistral.py +++ b/src/transformers/models/mistral/modular_mistral.py @@ -15,6 +15,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -50,12 +51,13 @@ class MistralAttention(LlamaAttention): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -69,10 +71,10 @@ class MistralAttention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -159,7 +161,7 @@ class MistralModel(LlamaModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 1ee3004593..41c3000c89 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -49,6 +49,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from .configuration_mixtral import MixtralConfig @@ -248,12 +249,13 @@ class MixtralAttention(nn.Module): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -267,10 +269,10 @@ class MixtralAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -304,13 +306,14 @@ class MixtralDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -324,7 +327,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -463,7 +466,7 @@ class MixtralModel(MixtralPreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index 8dc8c772e7..a2090d8d5a 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -33,6 +33,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder from ..mistral.modeling_mistral import ( MistralAttention, @@ -237,13 +238,14 @@ class MixtralDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> torch.FloatTensor: @@ -257,7 +259,7 @@ class MixtralDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -335,7 +337,7 @@ class MixtralModel(MistralModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/mllama/modeling_mllama.py b/src/transformers/models/mllama/modeling_mllama.py index 615eb34f6d..c47ad38414 100644 --- a/src/transformers/models/mllama/modeling_mllama.py +++ b/src/transformers/models/mllama/modeling_mllama.py @@ -33,6 +33,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_mllama import MllamaConfig, MllamaTextConfig, MllamaVisionConfig @@ -407,11 +408,12 @@ class MllamaTextCrossAttention(nn.Module): self.q_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) self.k_norm = MllamaTextRMSNorm(self.head_dim, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -430,16 +432,16 @@ class MllamaTextCrossAttention(nn.Module): value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) key_states = self.k_norm(key_states) - if past_key_value is not None: + if past_key_values is not None: # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! - key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) elif cache_position[0] != 0: key_states, value_states = ( - past_key_value.layers[self.layer_idx].keys, - past_key_value.layers[self.layer_idx].values, + past_key_values.layers[self.layer_idx].keys, + past_key_values.layers[self.layer_idx].values, ) else: raise ValueError( @@ -523,13 +525,14 @@ class MllamaTextSelfAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, use_cache: bool = False, - past_key_value=None, + past_key_values=None, cache_position=None, **kwargs, ): @@ -546,10 +549,10 @@ class MllamaTextSelfAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -605,6 +608,7 @@ class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -613,7 +617,7 @@ class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): attention_mask: Optional[torch.Tensor] = None, full_text_row_masked_out_mask: Optional[tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -629,7 +633,7 @@ class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -648,7 +652,7 @@ class MllamaSelfAttentionDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -680,6 +684,7 @@ class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = MllamaTextRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.cross_attn_mlp_gate = torch.nn.Parameter(torch.zeros(1)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -688,7 +693,7 @@ class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): attention_mask: torch.Tensor, full_text_row_masked_out_mask: tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[torch.Tensor] = None, @@ -701,7 +706,7 @@ class MllamaCrossAttentionDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -1257,7 +1262,7 @@ class MllamaTextModel(MllamaPreTrainedModel): attention_mask=causal_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py index 0a7d4a9084..ece5455e84 100644 --- a/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modeling_modernbert_decoder.py @@ -43,6 +43,7 @@ from ...models.modernbert.modeling_modernbert import ( ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_modernbert_decoder import ModernBertDecoderConfig @@ -113,12 +114,13 @@ class ModernBertDecoderAttention(nn.Module): self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -132,10 +134,10 @@ class ModernBertDecoderAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -173,13 +175,14 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer): self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings_global: torch.Tensor, position_embeddings_local: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -198,7 +201,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -368,7 +371,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel): position_embeddings_global=position_embeddings_global, position_embeddings_local=position_embeddings_local, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py index fe0f27d0d6..bb916f45fe 100644 --- a/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py +++ b/src/transformers/models/modernbert_decoder/modular_modernbert_decoder.py @@ -38,6 +38,7 @@ from ...models.modernbert.modeling_modernbert import ( ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs @@ -290,12 +291,13 @@ class ModernBertDecoderAttention(nn.Module): self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: torch.Tensor, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -309,10 +311,10 @@ class ModernBertDecoderAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -350,13 +352,14 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer): self.mlp_norm = nn.LayerNorm(config.hidden_size, eps=config.norm_eps, bias=config.norm_bias) self.mlp = ModernBertMLP(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings_global: torch.Tensor, position_embeddings_local: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -375,7 +378,7 @@ class ModernBertDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, position_embeddings=position_embeddings, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -545,7 +548,7 @@ class ModernBertDecoderModel(ModernBertDecoderPreTrainedModel): position_embeddings_global=position_embeddings_global, position_embeddings_local=position_embeddings_local, attention_mask=causal_mask_mapping[decoder_layer.attention_type], - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/moonshine/modeling_moonshine.py b/src/transformers/models/moonshine/modeling_moonshine.py index aa7dc909a6..dff9678612 100644 --- a/src/transformers/models/moonshine/modeling_moonshine.py +++ b/src/transformers/models/moonshine/modeling_moonshine.py @@ -44,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from .configuration_moonshine import MoonshineConfig @@ -205,12 +206,13 @@ class MoonshineAttention(nn.Module): else: self.head_dim_padding = 0 + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, key_value_states: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -222,20 +224,20 @@ class MoonshineAttention(nn.Module): ) is_cross_attention = key_value_states is not None - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + past_key_values.is_updated[self.layer_idx] = True + past_key_values = past_key_values.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + past_key_values = past_key_values.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states - if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + if is_cross_attention and past_key_values and is_updated: + key_states = past_key_values.layers[self.layer_idx].keys + value_states = past_key_values.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) @@ -247,8 +249,8 @@ class MoonshineAttention(nn.Module): .view(bsz, -1, self.config.num_key_value_heads, self.head_dim) .transpose(1, 2) ) - if is_cross_attention and past_key_value is not None: - key_states, value_states = past_key_value.update( + if is_cross_attention and past_key_values is not None: + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -256,9 +258,9 @@ class MoonshineAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -346,12 +348,13 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer): self.input_layernorm = nn.LayerNorm(config.hidden_size, bias=False) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -364,7 +367,7 @@ class MoonshineEncoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -405,6 +408,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -413,7 +417,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, encoder_position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -427,7 +431,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -442,7 +446,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = residual + hidden_states @@ -679,7 +683,7 @@ class MoonshineDecoder(MoonshinePreTrainedModel): encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/moonshine/modular_moonshine.py b/src/transformers/models/moonshine/modular_moonshine.py index 8452283e9c..cd81408283 100644 --- a/src/transformers/models/moonshine/modular_moonshine.py +++ b/src/transformers/models/moonshine/modular_moonshine.py @@ -38,6 +38,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ..glm.modeling_glm import GlmAttention, GlmRotaryEmbedding, apply_rotary_pos_emb from ..llama.modeling_llama import LlamaDecoderLayer, LlamaModel, eager_attention_forward from ..whisper.modeling_whisper import WhisperModel, shift_tokens_right @@ -304,12 +305,13 @@ class MoonshineAttention(GlmAttention): else: self.head_dim_padding = 0 + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, key_value_states: Optional[torch.Tensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -321,20 +323,20 @@ class MoonshineAttention(GlmAttention): ) is_cross_attention = key_value_states is not None - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + past_key_values.is_updated[self.layer_idx] = True + past_key_values = past_key_values.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + past_key_values = past_key_values.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states - if is_cross_attention and past_key_value and is_updated: - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + if is_cross_attention and past_key_values and is_updated: + key_states = past_key_values.layers[self.layer_idx].keys + value_states = past_key_values.layers[self.layer_idx].values else: key_states = ( self.k_proj(current_states) @@ -346,8 +348,8 @@ class MoonshineAttention(GlmAttention): .view(bsz, -1, self.config.num_key_value_heads, self.head_dim) .transpose(1, 2) ) - if is_cross_attention and past_key_value is not None: - key_states, value_states = past_key_value.update( + if is_cross_attention and past_key_values is not None: + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -355,9 +357,9 @@ class MoonshineAttention(GlmAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -438,6 +440,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, bias=False) self.final_layernorm = nn.LayerNorm(config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -446,7 +449,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, encoder_position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, @@ -460,7 +463,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -475,7 +478,7 @@ class MoonshineDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, ) hidden_states = residual + hidden_states @@ -702,7 +705,7 @@ class MoonshineDecoder(LlamaModel): encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/moshi/modeling_moshi.py b/src/transformers/models/moshi/modeling_moshi.py index 2e1581fbdb..de8f8c1087 100644 --- a/src/transformers/models/moshi/modeling_moshi.py +++ b/src/transformers/models/moshi/modeling_moshi.py @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ..auto.modeling_auto import AutoModel from .configuration_moshi import MoshiConfig, MoshiDepthConfig @@ -431,14 +432,13 @@ class MoshiAttention(nn.Module): self.rope_theta = config.rope_theta self.rotary_emb = MoshiRotaryEmbedding(config) - # copied from transformers.models.gemma.modeling_gemma.GemmaAttention.forward - # no longer copied after attention refactors + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -457,14 +457,14 @@ class MoshiAttention(nn.Module): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -514,17 +514,18 @@ class MoshiFlashAttention2(MoshiAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -549,14 +550,14 @@ class MoshiFlashAttention2(MoshiAttention): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -629,12 +630,13 @@ class MoshiSdpaAttention(MoshiAttention): """ # Adapted from MoshiAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -650,7 +652,7 @@ class MoshiSdpaAttention(MoshiAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -670,14 +672,14 @@ class MoshiSdpaAttention(MoshiAttention): cos, sin = self.rotary_emb(value_states, position_ids) # Ignore copy query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) # Ignore copy - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = ( {"sin": sin, "cos": cos, "cache_position": cache_position} if self.rotary_emb is not None else {"cache_position": cache_position} ) # Ignore copy - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -738,12 +740,13 @@ class MoshiDecoderLayer(GradientCheckpointingLayer): self._attn_implementation = config._attn_implementation + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -761,7 +764,7 @@ class MoshiDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence kwargs (`dict`, *optional*): @@ -777,7 +780,7 @@ class MoshiDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1009,7 +1012,7 @@ class MoshiDepthDecoder(MoshiPreTrainedModel, GenerationMixin): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1295,7 +1298,7 @@ class MoshiModel(MoshiPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/mpt/modeling_mpt.py b/src/transformers/models/mpt/modeling_mpt.py index 849b3c4851..891602c338 100644 --- a/src/transformers/models/mpt/modeling_mpt.py +++ b/src/transformers/models/mpt/modeling_mpt.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_mpt import MptConfig @@ -86,11 +87,12 @@ class MptAttention(nn.Module): self.out_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_bias: torch.Tensor, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, ): @@ -105,12 +107,12 @@ class MptAttention(nn.Module): key_states = key_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale - query_length = seq_length if past_key_value is None else seq_length + past_key_value.get_seq_length() + query_length = seq_length if past_key_values is None else seq_length + past_key_values.get_seq_length() if position_bias is not None: if len(position_bias.shape) != 3: @@ -201,7 +203,7 @@ class MptBlock(GradientCheckpointingLayer): layernorm_output, position_bias=position_bias, attention_mask=attention_mask, - past_key_value=layer_past, + past_key_values=layer_past, cache_position=cache_position, ) @@ -246,13 +248,14 @@ class MptPreTrainedModel(PreTrainedModel): module.weight.data.fill_(1.0) @staticmethod + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def _convert_to_mpt_cache( - past_key_value: tuple[tuple[torch.Tensor, torch.Tensor]], + past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]], ) -> tuple[tuple[torch.Tensor, torch.Tensor]]: """ Converts the cache to the format expected by Mpt, i.e. to tuple(tuple([batch_size * num_heads, ...])) """ - batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape + batch_size, num_heads, head_dim, seq_length = past_key_values[0][0].shape batch_size_times_num_heads = batch_size * num_heads # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length] # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim] @@ -261,7 +264,7 @@ class MptPreTrainedModel(PreTrainedModel): layer_past[0].reshape(batch_size_times_num_heads, head_dim, seq_length), layer_past[1].reshape(batch_size_times_num_heads, seq_length, head_dim), ) - for layer_past in past_key_value + for layer_past in past_key_values ) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 0a414cdf11..8d8e052b32 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -50,6 +50,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_mt5 import MT5Config @@ -339,13 +340,14 @@ class MT5Attention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -366,18 +368,18 @@ class MT5Attention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -387,7 +389,7 @@ class MT5Attention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -395,7 +397,7 @@ class MT5Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -460,13 +462,14 @@ class MT5LayerSelfAttention(nn.Module): self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -477,7 +480,7 @@ class MT5LayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -495,6 +498,7 @@ class MT5LayerCrossAttention(nn.Module): self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -502,7 +506,7 @@ class MT5LayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -515,7 +519,7 @@ class MT5LayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -540,6 +544,7 @@ class MT5Block(GradientCheckpointingLayer): self.layer.append(MT5LayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -550,7 +555,7 @@ class MT5Block(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -561,7 +566,7 @@ class MT5Block(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -586,7 +591,7 @@ class MT5Block(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1085,7 +1090,7 @@ class MT5Stack(MT5PreTrainedModel): encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, return_dict=return_dict, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 97c16b7e25..a4beb1ddf9 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -55,6 +55,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig @@ -219,11 +220,12 @@ class MusicgenAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -248,19 +250,19 @@ class MusicgenAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -268,7 +270,7 @@ class MusicgenAttention(nn.Module): key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -276,7 +278,7 @@ class MusicgenAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -335,8 +337,7 @@ class MusicgenDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward - # TODO: change to new cache class + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -345,7 +346,7 @@ class MusicgenDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -363,7 +364,7 @@ class MusicgenDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -374,7 +375,7 @@ class MusicgenDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -394,7 +395,7 @@ class MusicgenDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -619,7 +620,7 @@ class MusicgenDecoder(MusicgenPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -807,7 +808,7 @@ class MusicgenModel(MusicgenPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 92cad3cacc..f2c3d6af4b 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -47,6 +47,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig @@ -227,11 +228,12 @@ class MusicgenMelodyAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -256,19 +258,19 @@ class MusicgenMelodyAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -276,7 +278,7 @@ class MusicgenMelodyAttention(nn.Module): key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -284,7 +286,7 @@ class MusicgenMelodyAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -334,12 +336,13 @@ class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -350,7 +353,7 @@ class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): attention_mask (`torch.FloatTensor`): attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size `(attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -361,7 +364,7 @@ class MusicgenMelodyDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -583,7 +586,7 @@ class MusicgenMelodyDecoder(MusicgenMelodyPreTrainedModel): hidden_states, attention_mask=attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -737,7 +740,7 @@ class MusicgenMelodyModel(MusicgenMelodyPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 1102b3c2f7..c15975fd6f 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_mvp import MvpConfig @@ -122,11 +123,12 @@ class MvpAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, attn_prompt: Optional[torch.Tensor] = None, @@ -144,19 +146,19 @@ class MvpAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -166,7 +168,7 @@ class MvpAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -174,7 +176,7 @@ class MvpAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True if attn_prompt is not None: key_states = torch.cat([attn_prompt[0].expand(bsz, -1, -1, -1), key_states], dim=2) @@ -345,6 +347,7 @@ class MvpDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -355,7 +358,7 @@ class MvpDecoderLayer(GradientCheckpointingLayer): cross_attn_layer_head_mask: Optional[torch.Tensor] = None, self_attn_prompt: Optional[torch.Tensor] = None, cross_attn_prompt: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -377,7 +380,7 @@ class MvpDecoderLayer(GradientCheckpointingLayer): `(2, decoder_attention_heads, pro_len, head_dim)`. cross_attn_prompt (`torch.FloatTensor`): prompt of cross attention of shape `(2, decoder_attention_heads, pro_len, head_dim)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -387,7 +390,7 @@ class MvpDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, attn_prompt=self_attn_prompt, @@ -409,7 +412,7 @@ class MvpDecoderLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, attn_prompt=cross_attn_prompt, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -920,7 +923,7 @@ class MvpDecoder(MvpPreTrainedModel): cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), self_attn_prompt=(self_attn_prompt[idx] if self.use_prompt else None), cross_attn_prompt=(cross_attn_prompt[idx] if self.use_prompt else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1082,7 +1085,7 @@ class MvpModel(MvpPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/nemotron/modeling_nemotron.py b/src/transformers/models/nemotron/modeling_nemotron.py index 71a2e2c876..f8e2afb889 100644 --- a/src/transformers/models/nemotron/modeling_nemotron.py +++ b/src/transformers/models/nemotron/modeling_nemotron.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_nemotron import NemotronConfig @@ -224,13 +225,14 @@ class NemotronAttention(nn.Module): self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias) self.o_proj = nn.Linear(self.head_dim * self.num_heads, self.hidden_size, bias=config.attention_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -249,10 +251,10 @@ class NemotronAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -297,19 +299,19 @@ class NemotronFlashAttention2(NemotronAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - # Ignore copy + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: - if isinstance(past_key_value, StaticCache): + if isinstance(past_key_values, StaticCache): raise ValueError( "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" @@ -334,10 +336,10 @@ class NemotronFlashAttention2(NemotronAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -410,14 +412,14 @@ class NemotronSdpaAttention(NemotronAttention): SDPA API. """ - # Ignore copy + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -433,7 +435,7 @@ class NemotronSdpaAttention(NemotronAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -454,10 +456,10 @@ class NemotronSdpaAttention(NemotronAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -515,12 +517,13 @@ class NemotronDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) self.post_attention_layernorm = NemotronLayerNorm1P(config.hidden_size, eps=config.norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -539,7 +542,7 @@ class NemotronDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -558,7 +561,7 @@ class NemotronDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -701,7 +704,7 @@ class NemotronModel(NemotronPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 5e6dc2d52d..819ebef200 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -43,6 +43,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_nllb_moe import NllbMoeConfig @@ -541,11 +542,12 @@ class NllbMoeAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -570,19 +572,19 @@ class NllbMoeAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -590,7 +592,7 @@ class NllbMoeAttention(nn.Module): key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -598,7 +600,7 @@ class NllbMoeAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -739,6 +741,7 @@ class NllbMoeDecoderLayer(GradientCheckpointingLayer): self.ff_layer_norm = nn.LayerNorm(config.d_model) self.ff_dropout = nn.Dropout(config.activation_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -747,7 +750,7 @@ class NllbMoeDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = True, @@ -769,7 +772,7 @@ class NllbMoeDecoderLayer(GradientCheckpointingLayer): mask for attention heads in a given layer of size `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -781,7 +784,7 @@ class NllbMoeDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -799,7 +802,7 @@ class NllbMoeDecoderLayer(GradientCheckpointingLayer): hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, output_attentions=output_attentions, @@ -1298,7 +1301,7 @@ class NllbMoeDecoder(NllbMoePreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, @@ -1545,7 +1548,7 @@ class NllbMoeModel(NllbMoePreTrainedModel): router_probs=encoder_outputs[3] if len(encoder_outputs) > 3 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 80a1c47b98..d32fb0d820 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -20,6 +20,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_olmo import OlmoConfig @@ -153,12 +154,13 @@ class OlmoAttention(nn.Module): config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -181,10 +183,10 @@ class OlmoAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -216,12 +218,13 @@ class OlmoDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = OlmoLayerNorm(config.hidden_size) self.post_attention_layernorm = OlmoLayerNorm(config.hidden_size) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -234,7 +237,7 @@ class OlmoDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -370,7 +373,7 @@ class OlmoModel(OlmoPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/olmo/modular_olmo.py b/src/transformers/models/olmo/modular_olmo.py index 56690a2c1a..f54b910634 100644 --- a/src/transformers/models/olmo/modular_olmo.py +++ b/src/transformers/models/olmo/modular_olmo.py @@ -8,6 +8,7 @@ import torch.utils.checkpoint from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -75,12 +76,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): class OlmoAttention(LlamaAttention): + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -103,10 +105,10 @@ class OlmoAttention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py index 5fa1aaeeac..a98943a88e 100644 --- a/src/transformers/models/olmo2/modeling_olmo2.py +++ b/src/transformers/models/olmo2/modeling_olmo2.py @@ -22,6 +22,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_olmo2 import Olmo2Config @@ -148,12 +149,13 @@ class Olmo2Attention(nn.Module): self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -171,10 +173,10 @@ class Olmo2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -222,12 +224,13 @@ class Olmo2DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_feedforward_layernorm = Olmo2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -238,7 +241,7 @@ class Olmo2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -375,7 +378,7 @@ class Olmo2Model(Olmo2PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, cache_position=cache_position, position_embeddings=position_embeddings, **kwargs, diff --git a/src/transformers/models/olmo2/modular_olmo2.py b/src/transformers/models/olmo2/modular_olmo2.py index 56635628aa..c7e4706976 100644 --- a/src/transformers/models/olmo2/modular_olmo2.py +++ b/src/transformers/models/olmo2/modular_olmo2.py @@ -9,6 +9,7 @@ from ...cache_utils import Cache from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import LlamaPreTrainedModel, LlamaRMSNorm, eager_attention_forward from ..olmo.configuration_olmo import OlmoConfig from ..olmo.modeling_olmo import ( @@ -194,12 +195,13 @@ class Olmo2Attention(OlmoAttention): self.q_norm = Olmo2RMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps) self.k_norm = Olmo2RMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[TransformersKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -217,10 +219,10 @@ class Olmo2Attention(OlmoAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -253,12 +255,13 @@ class Olmo2DecoderLayer(OlmoDecoderLayer): self.self_attn = Olmo2Attention(config=config, layer_idx=layer_idx) del self.input_layernorm + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -269,7 +272,7 @@ class Olmo2DecoderLayer(OlmoDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index 7c9cdc983f..6a4b970815 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -29,6 +29,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPas from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_olmoe import OlmoeConfig @@ -288,12 +289,13 @@ class OlmoeAttention(nn.Module): (self.hidden_size // self.num_heads) * self.num_key_value_heads, eps=config.rms_norm_eps ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -318,10 +320,10 @@ class OlmoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -370,12 +372,13 @@ class OlmoeFlashAttention2(OlmoeAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -404,10 +407,10 @@ class OlmoeFlashAttention2(OlmoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -476,12 +479,13 @@ class OlmoeSdpaAttention(OlmoeAttention): """ # Adapted from OlmoeAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -497,7 +501,7 @@ class OlmoeSdpaAttention(OlmoeAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -522,10 +526,10 @@ class OlmoeSdpaAttention(OlmoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -629,12 +633,13 @@ class OlmoeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = OlmoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -657,7 +662,7 @@ class OlmoeDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -676,7 +681,7 @@ class OlmoeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -821,7 +826,7 @@ class OlmoeModel(OlmoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/opt/modeling_opt.py b/src/transformers/models/opt/modeling_opt.py index ff5e8dfa01..738c556ace 100644 --- a/src/transformers/models/opt/modeling_opt.py +++ b/src/transformers/models/opt/modeling_opt.py @@ -36,6 +36,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_opt import OPTConfig @@ -138,10 +139,11 @@ class OPTAttention(nn.Module): self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=self.enable_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -164,9 +166,9 @@ class OPTAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation - key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -219,12 +221,13 @@ class OPTDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=config.enable_bias) self.final_layer_norm = nn.LayerNorm(self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_ids: Optional[torch.LongTensor] = None, @@ -244,7 +247,7 @@ class OPTDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence.. """ @@ -258,7 +261,7 @@ class OPTDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, position_ids=position_ids, attention_mask=attention_mask, layer_head_mask=layer_head_mask, @@ -655,7 +658,7 @@ class OPTDecoder(OPTPreTrainedModel): attention_mask=causal_mask, position_ids=position_ids, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -726,7 +729,7 @@ class OPTModel(OPTPreTrainedModel): use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=input_ids, attention_mask=attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 33fc066d89..174e9308cb 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -49,6 +49,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_pegasus import PegasusConfig @@ -185,11 +186,12 @@ class PegasusAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -214,19 +216,19 @@ class PegasusAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -236,7 +238,7 @@ class PegasusAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -244,7 +246,7 @@ class PegasusAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -366,6 +368,7 @@ class PegasusDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -374,7 +377,7 @@ class PegasusDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -392,7 +395,7 @@ class PegasusDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -406,7 +409,7 @@ class PegasusDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -426,7 +429,7 @@ class PegasusDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1119,7 +1122,7 @@ class PegasusDecoder(PegasusPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1293,7 +1296,7 @@ class PegasusModel(PegasusPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 029ad0a2e4..faf71a29f8 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -43,6 +43,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_pegasus_x import PegasusXConfig @@ -206,11 +207,12 @@ class PegasusXAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -235,19 +237,19 @@ class PegasusXAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -257,7 +259,7 @@ class PegasusXAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -265,7 +267,7 @@ class PegasusXAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -672,13 +674,14 @@ class PegasusXDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -692,7 +695,7 @@ class PegasusXDecoderLayer(GradientCheckpointingLayer): cross attention input to the layer of shape *(seq_len, batch, embed_dim)* encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size *(batch, 1, tgt_len, src_len)* where padding elements are indicated by very large negative values. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -707,7 +710,7 @@ class PegasusXDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, output_attentions=output_attentions, cache_position=cache_position, @@ -725,7 +728,7 @@ class PegasusXDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1369,7 +1372,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel): causal_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1537,7 +1540,7 @@ class PegasusXModel(PegasusXPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index 3143c50bdb..45c66707b0 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_persimmon import PersimmonConfig @@ -223,12 +224,13 @@ class PersimmonAttention(nn.Module): fused_qkv = fused_qkv.view(batch_size, seq_length, self.num_heads, 3, self.head_dim) return fused_qkv[..., 0, :], fused_qkv[..., 1, :], fused_qkv[..., 2, :] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -270,7 +272,7 @@ class PersimmonAttention(nn.Module): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # Specific to RoPE models with partial rotation cache_kwargs = { "sin": sin, @@ -278,7 +280,7 @@ class PersimmonAttention(nn.Module): "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -314,12 +316,13 @@ class PersimmonDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.dropout = nn.Dropout(config.hidden_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -335,7 +338,7 @@ class PersimmonDecoderLayer(GradientCheckpointingLayer): Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -359,7 +362,7 @@ class PersimmonDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -508,7 +511,7 @@ class PersimmonModel(PersimmonPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 265520eecf..dde7681402 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -23,6 +23,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_phi import PhiConfig @@ -128,12 +129,13 @@ class PhiAttention(nn.Module): config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -165,10 +167,10 @@ class PhiAttention(nn.Module): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -213,12 +215,13 @@ class PhiDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.resid_dropout = nn.Dropout(config.resid_pdrop) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -234,7 +237,7 @@ class PhiDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -399,7 +402,7 @@ class PhiModel(PhiPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/phi/modular_phi.py b/src/transformers/models/phi/modular_phi.py index c028ee32a0..56eb541cb0 100644 --- a/src/transformers/models/phi/modular_phi.py +++ b/src/transformers/models/phi/modular_phi.py @@ -12,6 +12,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..clip.modeling_clip import CLIPMLP from ..llama.modeling_llama import ( LlamaAttention, @@ -50,12 +51,13 @@ class PhiAttention(LlamaAttention): config.hidden_size // config.num_attention_heads, eps=config.layer_norm_eps, elementwise_affine=True ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -87,10 +89,10 @@ class PhiAttention(LlamaAttention): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -124,12 +126,13 @@ class PhiDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.resid_dropout = nn.Dropout(config.resid_pdrop) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -145,7 +148,7 @@ class PhiDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -248,7 +251,7 @@ class PhiModel(LlamaModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index 743b9eae58..71aef43c2c 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -43,6 +43,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from .configuration_phi3 import Phi3Config @@ -159,12 +160,13 @@ class Phi3Attention(nn.Module): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -184,10 +186,10 @@ class Phi3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -243,12 +245,13 @@ class Phi3DecoderLayer(GradientCheckpointingLayer): self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -261,7 +264,7 @@ class Phi3DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -399,7 +402,7 @@ class Phi3Model(Phi3PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/phi3/modular_phi3.py b/src/transformers/models/phi3/modular_phi3.py index 5227f98dde..03cd5ad4ae 100644 --- a/src/transformers/models/phi3/modular_phi3.py +++ b/src/transformers/models/phi3/modular_phi3.py @@ -27,6 +27,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..mistral.modeling_mistral import ( MistralDecoderLayer, MistralForCausalLM, @@ -113,12 +114,13 @@ class Phi3Attention(nn.Module): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -138,10 +140,10 @@ class Phi3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -173,12 +175,13 @@ class Phi3DecoderLayer(MistralDecoderLayer): self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -191,7 +194,7 @@ class Phi3DecoderLayer(MistralDecoderLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py index e52fd88e3c..449e1d8a3d 100644 --- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py @@ -46,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, can_return_tuple, torch_int +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import TransformersKwargs, check_model_inputs from .configuration_phi4_multimodal import Phi4MultimodalAudioConfig, Phi4MultimodalConfig, Phi4MultimodalVisionConfig @@ -1383,12 +1384,13 @@ class Phi4MultimodalAttention(nn.Module): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.qkv_proj = nn.Linear(config.hidden_size, op_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -1408,10 +1410,10 @@ class Phi4MultimodalAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1446,12 +1448,13 @@ class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer): self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -1464,7 +1467,7 @@ class Phi4MultimodalDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -1703,7 +1706,7 @@ class Phi4MultimodalModel(Phi4MultimodalPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py index 8772796b03..e77c7a0fa3 100644 --- a/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py +++ b/src/transformers/models/phi4_multimodal/modular_phi4_multimodal.py @@ -1556,7 +1556,7 @@ class Phi4MultimodalModel(Phi3Model, nn.Module): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index d46df04795..394d9c4282 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -35,6 +35,7 @@ from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPas from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_phimoe import PhimoeConfig @@ -268,12 +269,13 @@ class PhimoeAttention(nn.Module): def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -292,9 +294,9 @@ class PhimoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -336,12 +338,13 @@ class PhimoeFlashAttention2(PhimoeAttention): flash attention and deal with padding tokens in case the input contains any of them. """ + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -360,9 +363,9 @@ class PhimoeFlashAttention2(PhimoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -431,12 +434,13 @@ class PhimoeSdpaAttention(PhimoeAttention): """ # Adapted from PhimoeAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -452,7 +456,7 @@ class PhimoeSdpaAttention(PhimoeAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -471,9 +475,9 @@ class PhimoeSdpaAttention(PhimoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -812,12 +816,13 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer): config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -830,7 +835,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -856,7 +861,7 @@ class PhimoeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1008,7 +1013,7 @@ class PhimoeModel(PhimoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 8ef1ea8090..07623a0285 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -43,6 +43,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_pix2struct import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisionConfig @@ -163,7 +164,7 @@ class Pix2StructVisionAttention(nn.Module): """ # Input is (batch_size, seq_length, dim) # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # past_key_values[0] is (batch_size, n_heads, q_len - 1, dim_per_head) batch_size, seq_length = hidden_states.shape[:2] def to_projection_shape(states): @@ -733,13 +734,14 @@ class Pix2StructTextAttention(nn.Module): return values # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -760,18 +762,18 @@ class Pix2StructTextAttention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value and is_updated: + if is_cross_attention and past_key_values and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -781,7 +783,7 @@ class Pix2StructTextAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -789,7 +791,7 @@ class Pix2StructTextAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -854,13 +856,14 @@ class Pix2StructTextLayerSelfAttention(nn.Module): self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -871,7 +874,7 @@ class Pix2StructTextLayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -889,6 +892,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module): self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -896,7 +900,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -909,7 +913,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -937,6 +941,7 @@ class Pix2StructTextBlock(GradientCheckpointingLayer): self.mlp = Pix2StructTextLayerFF(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -947,7 +952,7 @@ class Pix2StructTextBlock(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -958,7 +963,7 @@ class Pix2StructTextBlock(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -979,7 +984,7 @@ class Pix2StructTextBlock(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1196,7 +1201,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 23184dd4c0..de92fb89aa 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -47,6 +47,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_plbart import PLBartConfig @@ -370,11 +371,12 @@ class PLBartAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -399,19 +401,19 @@ class PLBartAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -421,7 +423,7 @@ class PLBartAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -429,7 +431,7 @@ class PLBartAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -720,6 +722,7 @@ class PLBartDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -728,7 +731,7 @@ class PLBartDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -746,7 +749,7 @@ class PLBartDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -759,7 +762,7 @@ class PLBartDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -779,7 +782,7 @@ class PLBartDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1035,7 +1038,7 @@ class PLBartDecoder(PLBartPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1203,7 +1206,7 @@ class PLBartModel(PLBartPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index c0fcdea7af..2737039eef 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -381,7 +381,7 @@ class PLBartModel(PLBartPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/pop2piano/modeling_pop2piano.py b/src/transformers/models/pop2piano/modeling_pop2piano.py index 9676c39455..ed5a182ca1 100644 --- a/src/transformers/models/pop2piano/modeling_pop2piano.py +++ b/src/transformers/models/pop2piano/modeling_pop2piano.py @@ -33,6 +33,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPastAndCross from ...modeling_utils import PreTrainedModel from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, is_torch_flex_attn_available, is_torch_fx_proxy, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_pop2piano import Pop2PianoConfig @@ -283,13 +284,14 @@ class Pop2PianoAttention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -310,18 +312,18 @@ class Pop2PianoAttention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -331,7 +333,7 @@ class Pop2PianoAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -339,7 +341,7 @@ class Pop2PianoAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -404,13 +406,14 @@ class Pop2PianoLayerSelfAttention(nn.Module): self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -421,7 +424,7 @@ class Pop2PianoLayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -439,6 +442,7 @@ class Pop2PianoLayerCrossAttention(nn.Module): self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -446,7 +450,7 @@ class Pop2PianoLayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -459,7 +463,7 @@ class Pop2PianoLayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -486,6 +490,7 @@ class Pop2PianoBlock(GradientCheckpointingLayer): self.layer.append(Pop2PianoLayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -496,7 +501,7 @@ class Pop2PianoBlock(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -507,7 +512,7 @@ class Pop2PianoBlock(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -532,7 +537,7 @@ class Pop2PianoBlock(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -811,7 +816,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel): encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index 57a41b33c5..46fe08536e 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -32,6 +32,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_prophetnet import ProphetNetConfig @@ -437,13 +438,14 @@ class ProphetNetAttention(nn.Module): self.out_proj = nn.Linear(hidden_size, hidden_size) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, key_value_states: Optional[Tensor] = None, attention_mask: Optional[Tensor] = None, layer_head_mask: Optional[Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[Tensor, Optional[Tensor]]: @@ -461,19 +463,19 @@ class ProphetNetAttention(nn.Module): # previous time steps are cached - no need to recompute key and value if they are static query_states = self.query_proj(hidden_states) / (self.head_dim**0.5) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -483,7 +485,7 @@ class ProphetNetAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_attn_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -491,7 +493,7 @@ class ProphetNetAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True query_states = query_states.view(batch_size, tgt_len, self.num_attn_heads, self.head_dim).transpose(1, 2) src_len = key_states.size(2) @@ -606,10 +608,11 @@ class ProphetNetNgramSelfAttention(nn.Module): def prepare_for_onnx_export_(self): self.onnx_trace = True + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, - past_key_value: Optional[tuple[Tensor]] = None, + past_key_values: Optional[tuple[Tensor]] = None, attention_mask=None, layer_head_mask=None, extended_predict_attention_mask=None, @@ -655,11 +658,11 @@ class ProphetNetNgramSelfAttention(nn.Module): # ProphetNet has two separate attention layers, one for self and one for cross attention # We need to obtain the self attention only for this module, if `EncoderDecoderCache` - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - curr_past_key_value = past_key_value.self_attention_cache + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values main_key_states, main_value_states = curr_past_key_value.update( main_key_states, main_value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -954,6 +957,7 @@ class ProphetNetDecoderLayer(GradientCheckpointingLayer): self.feed_forward = ProphetNetFeedForward(config, config.decoder_ffn_dim) self.feed_forward_layer_norm = LayerNorm(config.hidden_size) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -966,7 +970,7 @@ class ProphetNetDecoderLayer(GradientCheckpointingLayer): main_relative_position_buckets=None, predict_relative_position_buckets=None, position_ids=None, - past_key_value=None, + past_key_values=None, use_cache: Optional[bool] = True, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, @@ -974,7 +978,7 @@ class ProphetNetDecoderLayer(GradientCheckpointingLayer): # 1st residual block ngram_attention_output, self_attn_weights, self_attn_weights_ngram = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, extended_predict_attention_mask=extended_predict_attention_mask, @@ -992,7 +996,7 @@ class ProphetNetDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attn_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = self.cross_attn_layer_norm(attention_output + hidden_states) @@ -1334,7 +1338,7 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel): main_relative_position_buckets=main_relative_position_buckets, predict_relative_position_buckets=predict_relative_position_buckets, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py index 6f47d3ae9a..5d39c58a24 100644 --- a/src/transformers/models/qwen2/modeling_qwen2.py +++ b/src/transformers/models/qwen2/modeling_qwen2.py @@ -26,6 +26,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_qwen2 import Qwen2Config @@ -136,12 +137,13 @@ class Qwen2Attention(nn.Module): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -155,10 +157,10 @@ class Qwen2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -214,12 +216,13 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -232,7 +235,7 @@ class Qwen2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -382,7 +385,7 @@ class Qwen2Model(Qwen2PreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/qwen2/modular_qwen2.py b/src/transformers/models/qwen2/modular_qwen2.py index 0421e5bace..030cf82c2c 100644 --- a/src/transformers/models/qwen2/modular_qwen2.py +++ b/src/transformers/models/qwen2/modular_qwen2.py @@ -13,6 +13,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from ..llama.modeling_llama import ( LlamaAttention, @@ -50,12 +51,13 @@ class Qwen2Attention(LlamaAttention): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -69,10 +71,10 @@ class Qwen2Attention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -170,7 +172,7 @@ class Qwen2Model(MistralModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index b12aec70b2..774b0cdabc 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, check_torch_load_is_safe, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.hub import cached_file from .configuration_qwen2_5_omni import ( Qwen2_5OmniAudioEncoderConfig, @@ -1363,12 +1364,13 @@ class Qwen2_5OmniAttention(nn.Module): self.rotary_emb = Qwen2_5OmniRotaryEmbedding(config=config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -1390,9 +1392,9 @@ class Qwen2_5OmniAttention(nn.Module): query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -1447,12 +1449,13 @@ class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1470,7 +1473,7 @@ class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -1490,7 +1493,7 @@ class Qwen2_5OmniDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1640,7 +1643,7 @@ class Qwen2_5OmniThinkerTextModel(Qwen2_5OmniPreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=text_position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -2220,7 +2223,7 @@ class Qwen2_5OmniTalkerModel(Qwen2_5OmniPreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=text_position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 25e6be0213..50e9a5c7fa 100644 --- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_5_vl import Qwen2_5_VLConfig, Qwen2_5_VLTextConfig, Qwen2_5_VLVisionConfig @@ -657,12 +658,13 @@ class Qwen2_5_VLAttention(nn.Module): self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -684,9 +686,9 @@ class Qwen2_5_VLAttention(nn.Module): query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -727,12 +729,13 @@ class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -750,7 +753,7 @@ class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -770,7 +773,7 @@ class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -919,7 +922,7 @@ class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=text_position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index d51a4951e1..0ad984798d 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -45,6 +45,7 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_moe import Qwen2MoeConfig @@ -309,13 +310,13 @@ class Qwen2MoeAttention(nn.Module): self.rotary_emb = Qwen2MoeRotaryEmbedding(config=self.config) - # Ignore copy + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -334,9 +335,9 @@ class Qwen2MoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -389,12 +390,13 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -413,9 +415,9 @@ class Qwen2MoeFlashAttention2(Qwen2MoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -496,12 +498,13 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention): """ # Adapted from Qwen2MoeAttention.forward + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -517,7 +520,7 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -537,9 +540,9 @@ class Qwen2MoeSdpaAttention(Qwen2MoeAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) @@ -664,12 +667,13 @@ class Qwen2MoeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen2MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -691,7 +695,7 @@ class Qwen2MoeDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -711,7 +715,7 @@ class Qwen2MoeDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -859,7 +863,7 @@ class Qwen2MoeModel(Qwen2MoePreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, output_router_logits=output_router_logits, use_cache=use_cache, diff --git a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py index 472a8b525a..5e422451ae 100644 --- a/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py +++ b/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py @@ -45,6 +45,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_qwen2_vl import Qwen2VLConfig, Qwen2VLTextConfig, Qwen2VLVisionConfig @@ -518,12 +519,13 @@ class Qwen2VLAttention(nn.Module): self.rotary_emb = Qwen2VLRotaryEmbedding(config=config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -545,9 +547,9 @@ class Qwen2VLAttention(nn.Module): query_states, key_states, cos, sin, self.rope_scaling["mrope_section"] ) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -588,12 +590,13 @@ class Qwen2VLDecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -611,7 +614,7 @@ class Qwen2VLDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -631,7 +634,7 @@ class Qwen2VLDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -891,7 +894,7 @@ class Qwen2VLTextModel(Qwen2VLPreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=text_position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py index ce9e0fbbea..785fca6d15 100644 --- a/src/transformers/models/qwen3/modeling_qwen3.py +++ b/src/transformers/models/qwen3/modeling_qwen3.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_qwen3 import Qwen3Config @@ -183,12 +184,13 @@ class Qwen3Attention(nn.Module): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -202,10 +204,10 @@ class Qwen3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -240,12 +242,13 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -258,7 +261,7 @@ class Qwen3DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -408,7 +411,7 @@ class Qwen3Model(Qwen3PreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/qwen3/modular_qwen3.py b/src/transformers/models/qwen3/modular_qwen3.py index 1a729ab0b4..f1e38841fa 100644 --- a/src/transformers/models/qwen3/modular_qwen3.py +++ b/src/transformers/models/qwen3/modular_qwen3.py @@ -24,6 +24,7 @@ from ...modeling_outputs import CausalLMOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..gemma.modeling_gemma import GemmaMLP from ..llama.modeling_llama import ( LlamaAttention, @@ -63,12 +64,13 @@ class Qwen3Attention(LlamaAttention): self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -82,10 +84,10 @@ class Qwen3Attention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index e9cffdb523..340281b7d9 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -42,6 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import OutputRecorder, check_model_inputs from .configuration_qwen3_moe import Qwen3MoeConfig @@ -147,12 +148,13 @@ class Qwen3MoeAttention(nn.Module): self.k_norm = Qwen3MoeRMSNorm(self.head_dim, eps=config.rms_norm_eps) # thus post q_norm does not need reshape self.sliding_window = getattr(config, "sliding_window", None) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -166,10 +168,10 @@ class Qwen3MoeAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -300,13 +302,14 @@ class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> torch.FloatTensor: @@ -324,7 +327,7 @@ class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): Indices depicting the position of the input sequence tokens in the sequence. position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): @@ -344,7 +347,7 @@ class Qwen3MoeDecoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) @@ -486,7 +489,7 @@ class Qwen3MoeModel(Qwen3MoePreTrainedModel): position_embeddings=position_embeddings, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, **kwargs, diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py index 8cf50cdce2..64f9c45725 100644 --- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py @@ -27,6 +27,7 @@ from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaForQuestionAnswering, LlamaForSequenceClassification, @@ -139,13 +140,14 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer, nn.Module): self.input_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = Qwen3MoeRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> torch.FloatTensor: @@ -159,7 +161,7 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer, nn.Module): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, **kwargs, ) diff --git a/src/transformers/models/rag/modeling_tf_rag.py b/src/transformers/models/rag/modeling_tf_rag.py index f430239e0d..1553837728 100644 --- a/src/transformers/models/rag/modeling_tf_rag.py +++ b/src/transformers/models/rag/modeling_tf_rag.py @@ -567,27 +567,31 @@ class TFRagModel(TFRagPreTrainedModel): **kwargs, ) -> TFRetrievAugLMOutput: r""" - Returns: + Returns: - Example: + Example: - ```python - >>> from transformers import AutoTokenizer, RagRetriever, TFRagModel - >>> import torch + ```python + >>> from transformers import AutoTokenizer, RagRetriever, TFRagModel + >>> import torch + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg - >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base") - >>> retriever = RagRetriever.from_pretrained( - ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True - ... ) - >>> # initialize with RagRetriever to do everything in one forward call - >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True) + >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base") + >>> retriever = RagRetriever.from_pretrained( + ... "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True + ... ) + >>> # initialize with RagRetriever to do everything in one forward call + >>> model = TFRagModel.from_pretrained("facebook/rag-token-base", retriever=retriever, from_pt=True) - >>> input_dict = tokenizer.prepare_seq2seq_batch( - ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf" - ... ) - >>> input_ids = input_dict["input_ids"] - >>> outputs = model(input_ids) - ```""" + >>> input_dict = tokenizer.prepare_seq2seq_batch( + ... "How many people live in Paris?", "In Paris, there are 10 million people.", return_tensors="tf" + ... ) + >>> input_ids = input_dict["input_ids"] + >>> outputs = model(input_ids) + ```""" assert "decoder_cached_states" not in kwargs, ( "Please use past_key_values to cache intermediate outputs" ) # from modeling_tf_bart.py diff --git a/src/transformers/models/reformer/modeling_reformer.py b/src/transformers/models/reformer/modeling_reformer.py index 97cb08deb9..65badc59b7 100755 --- a/src/transformers/models/reformer/modeling_reformer.py +++ b/src/transformers/models/reformer/modeling_reformer.py @@ -79,7 +79,7 @@ class ReformerDynamicCache(DynamicCache): def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]: """ - Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the + Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the sequence length. """ if layer_idx < len(self): @@ -89,7 +89,7 @@ class ReformerDynamicCache(DynamicCache): def __iter__(self): """ - Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over + Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over keys and values """ for layer_idx in range(len(self)): @@ -97,7 +97,7 @@ class ReformerDynamicCache(DynamicCache): def __len__(self): """ - Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds + Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds to the number of layers in the model. """ return len(self.states_cache) diff --git a/src/transformers/models/rembert/modeling_rembert.py b/src/transformers/models/rembert/modeling_rembert.py index dc9c5f86ea..fc3f1862e6 100755 --- a/src/transformers/models/rembert/modeling_rembert.py +++ b/src/transformers/models/rembert/modeling_rembert.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_rembert import RemBertConfig @@ -221,13 +222,14 @@ class RemBertSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple: @@ -239,19 +241,19 @@ class RemBertSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -267,7 +269,7 @@ class RemBertSelfAttention(nn.Module): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -275,7 +277,7 @@ class RemBertSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -353,7 +355,7 @@ class RemBertAttention(nn.Module): attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -362,7 +364,7 @@ class RemBertAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -425,7 +427,7 @@ class RemBertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -434,7 +436,7 @@ class RemBertLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -452,7 +454,7 @@ class RemBertLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py index 26a10273a5..e1770bb4db 100644 --- a/src/transformers/models/roberta/modeling_roberta.py +++ b/src/transformers/models/roberta/modeling_roberta.py @@ -42,6 +42,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta import RobertaConfig @@ -166,13 +167,14 @@ class RobertaSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -183,19 +185,19 @@ class RobertaSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -209,7 +211,7 @@ class RobertaSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -217,14 +219,14 @@ class RobertaSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -277,13 +279,14 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from RobertaSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -301,7 +304,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -314,19 +317,19 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -342,7 +345,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -350,7 +353,7 @@ class RobertaSdpaSelfAttention(RobertaSelfAttention): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -432,13 +435,14 @@ class RobertaAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -447,7 +451,7 @@ class RobertaAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -503,6 +507,7 @@ class RobertaLayer(GradientCheckpointingLayer): self.intermediate = RobertaIntermediate(config) self.output = RobertaOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -510,7 +515,7 @@ class RobertaLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -519,7 +524,7 @@ class RobertaLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -537,7 +542,7 @@ class RobertaLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -612,7 +617,7 @@ class RobertaEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py index c36f029cf3..30cc18801d 100644 --- a/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +++ b/src/transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roberta_prelayernorm import RobertaPreLayerNormConfig @@ -165,13 +166,14 @@ class RobertaPreLayerNormSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -182,19 +184,19 @@ class RobertaPreLayerNormSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -208,7 +210,7 @@ class RobertaPreLayerNormSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -216,14 +218,14 @@ class RobertaPreLayerNormSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -310,13 +312,14 @@ class RobertaPreLayerNormAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -326,7 +329,7 @@ class RobertaPreLayerNormAttention(nn.Module): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -383,6 +386,7 @@ class RobertaPreLayerNormLayer(GradientCheckpointingLayer): self.intermediate = RobertaPreLayerNormIntermediate(config) self.output = RobertaPreLayerNormOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -390,7 +394,7 @@ class RobertaPreLayerNormLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -399,7 +403,7 @@ class RobertaPreLayerNormLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -417,7 +421,7 @@ class RobertaPreLayerNormLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -494,7 +498,7 @@ class RobertaPreLayerNormEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/roc_bert/modeling_roc_bert.py b/src/transformers/models/roc_bert/modeling_roc_bert.py index a4e736c9f4..69f9787d22 100644 --- a/src/transformers/models/roc_bert/modeling_roc_bert.py +++ b/src/transformers/models/roc_bert/modeling_roc_bert.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roc_bert import RoCBertConfig @@ -280,13 +281,14 @@ class RoCBertSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -297,19 +299,19 @@ class RoCBertSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -323,7 +325,7 @@ class RoCBertSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -331,14 +333,14 @@ class RoCBertSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -433,13 +435,14 @@ class RoCBertAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -448,7 +451,7 @@ class RoCBertAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -504,6 +507,7 @@ class RoCBertLayer(GradientCheckpointingLayer): self.intermediate = RoCBertIntermediate(config) self.output = RoCBertOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -511,7 +515,7 @@ class RoCBertLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -520,7 +524,7 @@ class RoCBertLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -538,7 +542,7 @@ class RoCBertLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -613,7 +617,7 @@ class RoCBertEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 5a816efb61..2a4ddc6941 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_roformer import RoFormerConfig @@ -211,6 +212,7 @@ class RoFormerSelfAttention(nn.Module): self.rotary_value = config.rotary_value self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -218,7 +220,7 @@ class RoFormerSelfAttention(nn.Module): sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -233,19 +235,19 @@ class RoFormerSelfAttention(nn.Module): # such that the encoder's padding tokens are not attended to. is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -272,7 +274,7 @@ class RoFormerSelfAttention(nn.Module): sinusoidal_pos, query_layer, key_layer ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -280,7 +282,7 @@ class RoFormerSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -378,7 +380,7 @@ class RoFormerAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) - # End Copy + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -386,7 +388,7 @@ class RoFormerAttention(nn.Module): sinusoidal_pos=None, head_mask=None, encoder_hidden_states=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -396,7 +398,7 @@ class RoFormerAttention(nn.Module): sinusoidal_pos=sinusoidal_pos, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -451,6 +453,7 @@ class RoFormerLayer(GradientCheckpointingLayer): self.intermediate = RoFormerIntermediate(config) self.output = RoFormerOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -459,7 +462,7 @@ class RoFormerLayer(GradientCheckpointingLayer): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -469,7 +472,7 @@ class RoFormerLayer(GradientCheckpointingLayer): sinusoidal_pos=sinusoidal_pos, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -488,7 +491,7 @@ class RoFormerLayer(GradientCheckpointingLayer): sinusoidal_pos=sinusoidal_pos, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py index 770a9b0326..9098ec776b 100644 --- a/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py +++ b/src/transformers/models/rt_detr/modeling_rt_detr_resnet.py @@ -339,24 +339,29 @@ class RTDetrResNetBackbone(RTDetrResNetPreTrainedModel, BackboneMixin): self, pixel_values: Tensor, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None ) -> BackboneOutput: r""" - Examples: + Examples: - ```python - >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone - >>> import torch + ```python + >>> from transformers import RTDetrResNetConfig, RTDetrResNetBackbone + >>> import torch + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg + from ...utils.deprecation import deprecate_kwarg - >>> config = RTDetrResNetConfig() - >>> model = RTDetrResNetBackbone(config) + >>> config = RTDetrResNetConfig() + >>> model = RTDetrResNetBackbone(config) - >>> pixel_values = torch.randn(1, 3, 224, 224) + >>> pixel_values = torch.randn(1, 3, 224, 224) - >>> with torch.no_grad(): - ... outputs = model(pixel_values) + >>> with torch.no_grad(): + ... outputs = model(pixel_values) - >>> feature_maps = outputs.feature_maps - >>> list(feature_maps[-1].shape) - [1, 2048, 7, 7] - ```""" + >>> feature_maps = outputs.feature_maps + >>> list(feature_maps[-1].shape) + [1, 2048, 7, 7] + ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index 6fb9c2ab3b..2a89449894 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -43,6 +43,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_seamless_m4t import SeamlessM4TConfig @@ -1028,11 +1029,12 @@ class SeamlessM4TAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, @@ -1048,19 +1050,19 @@ class SeamlessM4TAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -1070,7 +1072,7 @@ class SeamlessM4TAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -1078,7 +1080,7 @@ class SeamlessM4TAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -1258,13 +1260,14 @@ class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) self.ffn_dropout = nn.Dropout(config.activation_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -1281,7 +1284,7 @@ class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - past_key_value (`Tuple(torch.FloatTensor)`): + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -1293,7 +1296,7 @@ class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, output_attentions=output_attentions, cache_position=cache_position, @@ -1310,7 +1313,7 @@ class SeamlessM4TDecoderLayer(GradientCheckpointingLayer): hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=encoder_attention_mask, output_attentions=output_attentions, cache_position=cache_position, @@ -1845,7 +1848,7 @@ class SeamlessM4TDecoder(SeamlessM4TPreTrainedModel): attention_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1944,7 +1947,7 @@ class SeamlessM4TTextToUnitModel(SeamlessM4TPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -2557,7 +2560,7 @@ class SeamlessM4TForTextToText(SeamlessM4TPreTrainedModel, GenerationMixin): encoder_attention_mask = attention_mask - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -2816,7 +2819,7 @@ class SeamlessM4TForSpeechToText(SeamlessM4TPreTrainedModel, GenerationMixin): hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3084,7 +3087,7 @@ class SeamlessM4TForTextToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): encoder_attention_mask = attention_mask - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3409,7 +3412,7 @@ class SeamlessM4TForSpeechToSpeech(SeamlessM4TPreTrainedModel, GenerationMixin): hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3799,7 +3802,7 @@ class SeamlessM4TModel(SeamlessM4TPreTrainedModel, GenerationMixin): hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index c7ab1db016..821586e1fc 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -40,6 +40,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import PreTrainedModel from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_seamless_m4t_v2 import SeamlessM4Tv2Config @@ -900,11 +901,12 @@ class SeamlessM4Tv2Attention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, cache_position: Optional[torch.Tensor] = None, @@ -914,19 +916,19 @@ class SeamlessM4Tv2Attention(nn.Module): is_cross_attention = encoder_hidden_states is not None batch_size, seq_length = hidden_states.shape[:2] - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -936,7 +938,7 @@ class SeamlessM4Tv2Attention(nn.Module): key_states = key_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -944,7 +946,7 @@ class SeamlessM4Tv2Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True query_states = self.q_proj(hidden_states) query_states = query_states.reshape(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2) @@ -1092,13 +1094,14 @@ class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): self.ffn_layer_norm = nn.LayerNorm(config.hidden_size) self.ffn_dropout = nn.Dropout(config.activation_dropout) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -1115,7 +1118,7 @@ class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - past_key_value (`Tuple(torch.FloatTensor)`): + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -1127,7 +1130,7 @@ class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, output_attentions=output_attentions, cache_position=cache_position, @@ -1144,7 +1147,7 @@ class SeamlessM4Tv2DecoderLayer(GradientCheckpointingLayer): hidden_states, cross_attn_weights = self.cross_attention( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=encoder_attention_mask, output_attentions=output_attentions, cache_position=cache_position, @@ -1888,7 +1891,7 @@ class SeamlessM4Tv2Decoder(SeamlessM4Tv2PreTrainedModel): attention_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -2765,7 +2768,7 @@ class SeamlessM4Tv2ForTextToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin): encoder_attention_mask = attention_mask - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3031,7 +3034,7 @@ class SeamlessM4Tv2ForSpeechToText(SeamlessM4Tv2PreTrainedModel, GenerationMixin hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3307,7 +3310,7 @@ class SeamlessM4Tv2ForTextToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMixin encoder_attention_mask = attention_mask - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -3670,7 +3673,7 @@ class SeamlessM4Tv2ForSpeechToSpeech(SeamlessM4Tv2PreTrainedModel, GenerationMix hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, @@ -4097,7 +4100,7 @@ class SeamlessM4Tv2Model(SeamlessM4Tv2PreTrainedModel, GenerationMixin): hidden_states=encoder_outputs[0], seq_lens=sub_sampled_lengths ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.text_decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py index 3a891cff6c..7d34302888 100644 --- a/src/transformers/models/smollm3/modeling_smollm3.py +++ b/src/transformers/models/smollm3/modeling_smollm3.py @@ -41,6 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from ...utils.generic import check_model_inputs from .configuration_smollm3 import SmolLM3Config @@ -150,12 +151,13 @@ class SmolLM3Attention(nn.Module): else None ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -170,9 +172,9 @@ class SmolLM3Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -244,12 +246,13 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer): self.post_attention_layernorm = SmolLM3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attention_type = config.layer_types[layer_idx] + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -262,7 +265,7 @@ class SmolLM3DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -412,7 +415,7 @@ class SmolLM3Model(SmolLM3PreTrainedModel): hidden_states, attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/smollm3/modular_smollm3.py b/src/transformers/models/smollm3/modular_smollm3.py index b19083f650..e06fd1e1b1 100644 --- a/src/transformers/models/smollm3/modular_smollm3.py +++ b/src/transformers/models/smollm3/modular_smollm3.py @@ -24,6 +24,7 @@ from ...modeling_rope_utils import rope_config_validation from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import logging +from ...utils.deprecation import deprecate_kwarg from ..llama.modeling_llama import ( LlamaAttention, LlamaDecoderLayer, @@ -273,12 +274,13 @@ class SmolLM3Attention(LlamaAttention): else None ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: @@ -293,9 +295,9 @@ class SmolLM3Attention(LlamaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = {"cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 81b36c8bf9..56b9a582a0 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -45,6 +45,7 @@ from ...utils import ( is_torch_flex_attn_available, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_speech_to_text import Speech2TextConfig @@ -243,11 +244,12 @@ class Speech2TextAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -272,19 +274,19 @@ class Speech2TextAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -292,7 +294,7 @@ class Speech2TextAttention(nn.Module): key_states = self.k_proj(current_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(current_states).view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -300,7 +302,7 @@ class Speech2TextAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -423,6 +425,7 @@ class Speech2TextDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, @@ -432,7 +435,7 @@ class Speech2TextDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -450,7 +453,7 @@ class Speech2TextDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -461,7 +464,7 @@ class Speech2TextDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -481,7 +484,7 @@ class Speech2TextDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -928,7 +931,7 @@ class Speech2TextDecoder(Speech2TextPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1159,7 +1162,7 @@ class Speech2TextModel(Speech2TextPreTrainedModel): else: encoder_attention_mask = None - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/speecht5/modeling_speecht5.py b/src/transformers/models/speecht5/modeling_speecht5.py index afb5dba86c..60ddec72c3 100644 --- a/src/transformers/models/speecht5/modeling_speecht5.py +++ b/src/transformers/models/speecht5/modeling_speecht5.py @@ -39,6 +39,7 @@ from ...modeling_outputs import ( ) from ...modeling_utils import EmbeddingAccessMixin, PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_speecht5 import SpeechT5Config, SpeechT5HifiGanConfig @@ -876,11 +877,12 @@ class SpeechT5Attention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, position_bias: Optional[torch.Tensor] = None, @@ -898,19 +900,19 @@ class SpeechT5Attention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -920,7 +922,7 @@ class SpeechT5Attention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -928,7 +930,7 @@ class SpeechT5Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -1115,6 +1117,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer): self.feed_forward = SpeechT5FeedForward(config, config.decoder_ffn_dim) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1123,7 +1126,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -1141,7 +1144,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1151,7 +1154,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -1171,7 +1174,7 @@ class SpeechT5DecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1637,7 +1640,7 @@ class SpeechT5Decoder(SpeechT5PreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 978d9e3ecd..8e711c5891 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -43,6 +43,7 @@ from ...modeling_outputs import ( from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_stablelm import StableLmConfig @@ -220,12 +221,13 @@ class StableLmAttention(nn.Module): self.attention_dropout = nn.Dropout(config.attention_dropout) self.rotary_emb = StableLmRotaryEmbedding(config=self.config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -263,7 +265,7 @@ class StableLmAttention(nn.Module): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # Specific to RoPE models with partial rotation cache_kwargs = { "sin": sin, @@ -271,7 +273,7 @@ class StableLmAttention(nn.Module): "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # Repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -307,12 +309,13 @@ class StableLmAttention(nn.Module): class StableLmSdpaAttention(StableLmAttention): + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -328,7 +331,7 @@ class StableLmSdpaAttention(StableLmAttention): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -367,7 +370,7 @@ class StableLmSdpaAttention(StableLmAttention): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: # Specific to RoPE models with partial rotation cache_kwargs = { "sin": sin, @@ -375,7 +378,7 @@ class StableLmSdpaAttention(StableLmAttention): "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # Repeat k/v heads if n_kv_heads < n_heads key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -429,12 +432,13 @@ class StableLmFlashAttention2(StableLmAttention): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -479,14 +483,14 @@ class StableLmFlashAttention2(StableLmAttention): query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) - if past_key_value is not None: + if past_key_values is not None: cache_kwargs = { "sin": sin, "cos": cos, "partial_rotation_size": self.rotary_ndims, "cache_position": cache_position, } - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache # to be able to avoid many of these transpose/reshape/view. @@ -542,7 +546,7 @@ class StableLmDecoderLayer(GradientCheckpointingLayer): hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -558,7 +562,7 @@ class StableLmDecoderLayer(GradientCheckpointingLayer): `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) - past_key_value (`Tuple(torch.FloatTensor)`, *optional*): + past_key_values (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under @@ -582,7 +586,7 @@ class StableLmDecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -735,7 +739,7 @@ class StableLmModel(StableLmPreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index 281bacf4d5..98d9bf415f 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -46,6 +46,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple +from ...utils.deprecation import deprecate_kwarg from .configuration_starcoder2 import Starcoder2Config @@ -156,12 +157,13 @@ class Starcoder2Attention(nn.Module): self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) self.residual_dropout = config.residual_dropout + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -175,10 +177,10 @@ class Starcoder2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -214,12 +216,13 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer): self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC @@ -232,7 +235,7 @@ class Starcoder2DecoderLayer(GradientCheckpointingLayer): hidden_states=hidden_states, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -374,7 +377,7 @@ class Starcoder2Model(Starcoder2PreTrainedModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/starcoder2/modular_starcoder2.py b/src/transformers/models/starcoder2/modular_starcoder2.py index 349ffb8acb..81cad212b0 100644 --- a/src/transformers/models/starcoder2/modular_starcoder2.py +++ b/src/transformers/models/starcoder2/modular_starcoder2.py @@ -35,6 +35,7 @@ from ...modeling_outputs import BaseModelOutputWithPast from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, logging +from ...utils.deprecation import deprecate_kwarg from ..mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, @@ -78,12 +79,13 @@ class Starcoder2Attention(MistralAttention): self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.use_bias) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.use_bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -97,10 +99,10 @@ class Starcoder2Attention(MistralAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -201,7 +203,7 @@ class Starcoder2Model(MistralModel): hidden_states, attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, diff --git a/src/transformers/models/switch_transformers/modeling_switch_transformers.py b/src/transformers/models/switch_transformers/modeling_switch_transformers.py index d6241ff134..0027a48c0d 100644 --- a/src/transformers/models/switch_transformers/modeling_switch_transformers.py +++ b/src/transformers/models/switch_transformers/modeling_switch_transformers.py @@ -45,6 +45,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_switch_transformers import SwitchTransformersConfig @@ -476,13 +477,14 @@ class SwitchTransformersAttention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -503,18 +505,18 @@ class SwitchTransformersAttention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -524,7 +526,7 @@ class SwitchTransformersAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -532,7 +534,7 @@ class SwitchTransformersAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -597,13 +599,14 @@ class SwitchTransformersLayerSelfAttention(nn.Module): self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -614,7 +617,7 @@ class SwitchTransformersLayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -634,6 +637,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module): self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -641,7 +645,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -654,7 +658,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -691,7 +695,7 @@ class SwitchTransformersBlock(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, output_router_logits=True, @@ -703,7 +707,7 @@ class SwitchTransformersBlock(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -724,7 +728,7 @@ class SwitchTransformersBlock(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1022,7 +1026,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel): encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, output_router_logits=output_router_logits, diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index aeb8de0424..f8f4059c20 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -50,6 +50,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_t5 import T5Config @@ -464,13 +465,14 @@ class T5Attention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -491,18 +493,18 @@ class T5Attention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -512,7 +514,7 @@ class T5Attention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -520,7 +522,7 @@ class T5Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -584,13 +586,14 @@ class T5LayerSelfAttention(nn.Module): self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -601,7 +604,7 @@ class T5LayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -618,6 +621,7 @@ class T5LayerCrossAttention(nn.Module): self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -625,7 +629,7 @@ class T5LayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -638,7 +642,7 @@ class T5LayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -662,6 +666,7 @@ class T5Block(GradientCheckpointingLayer): self.layer.append(T5LayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -672,7 +677,7 @@ class T5Block(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -683,7 +688,7 @@ class T5Block(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -708,7 +713,7 @@ class T5Block(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1098,7 +1103,7 @@ class T5Stack(T5PreTrainedModel): encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, return_dict=return_dict, diff --git a/src/transformers/models/t5gemma/modeling_t5gemma.py b/src/transformers/models/t5gemma/modeling_t5gemma.py index 0a1b62396e..8fcb072ca8 100644 --- a/src/transformers/models/t5gemma/modeling_t5gemma.py +++ b/src/transformers/models/t5gemma/modeling_t5gemma.py @@ -44,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig @@ -235,12 +236,13 @@ class T5GemmaSelfAttention(nn.Module): self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -254,10 +256,10 @@ class T5GemmaSelfAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -312,12 +314,13 @@ class T5GemmaCrossAttention(nn.Module): if config.cross_attention_hidden_size is None: raise ValueError("Cross-attention needs cross_attention_hidden_size to be specified.") + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if encoder_hidden_states is None: @@ -327,19 +330,19 @@ class T5GemmaCrossAttention(nn.Module): hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) - curr_past_key_value = past_key_value.cross_attention_cache + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) + curr_past_key_value = past_key_values.cross_attention_cache - if past_key_value is None or not is_updated: + if past_key_values is None or not is_updated: encoder_input_shape = encoder_hidden_states.shape[:-1] encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True else: key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -404,7 +407,7 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=None, + past_key_values=None, **kwargs, ) hidden_states = self.post_self_attn_layernorm(hidden_states) @@ -427,13 +430,14 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, + past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, @@ -447,7 +451,7 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, + past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -461,7 +465,7 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, **kwargs, ) @@ -530,12 +534,13 @@ class T5GemmaAttention(nn.Module): self.attn_logit_softcapping = self.config.attn_logit_softcapping self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -549,10 +554,10 @@ class T5GemmaAttention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: + if past_key_values is not None: # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -712,6 +717,9 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # As we want to pass `past_key_values=None` explicitly everwhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/src/transformers/models/t5gemma/modular_t5gemma.py b/src/transformers/models/t5gemma/modular_t5gemma.py index 74b66ad530..92f1b59ea8 100644 --- a/src/transformers/models/t5gemma/modular_t5gemma.py +++ b/src/transformers/models/t5gemma/modular_t5gemma.py @@ -43,6 +43,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from ..gemma2.configuration_gemma2 import Gemma2Config from ..gemma2.modeling_gemma2 import ( Gemma2Attention, @@ -262,12 +263,13 @@ class T5GemmaCrossAttention(Gemma2Attention): config.cross_attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias ) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor], encoder_hidden_states: Optional[torch.Tensor], - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: if encoder_hidden_states is None: @@ -277,19 +279,19 @@ class T5GemmaCrossAttention(Gemma2Attention): hidden_shape = (*input_shape, -1, self.head_dim) query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if past_key_value is not None: - is_updated = past_key_value.is_updated.get(self.layer_idx) - curr_past_key_value = past_key_value.cross_attention_cache + if past_key_values is not None: + is_updated = past_key_values.is_updated.get(self.layer_idx) + curr_past_key_value = past_key_values.cross_attention_cache - if past_key_value is None or not is_updated: + if past_key_values is None or not is_updated: encoder_input_shape = encoder_hidden_states.shape[:-1] encoder_hidden_shape = (*encoder_input_shape, -1, self.head_dim) key_states = self.k_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) value_states = self.v_proj(encoder_hidden_states).view(encoder_hidden_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx) - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True else: key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -378,7 +380,7 @@ class T5GemmaEncoderLayer(GradientCheckpointingLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=None, + past_key_values=None, **kwargs, ) hidden_states = self.post_self_attn_layernorm(hidden_states) @@ -401,13 +403,14 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): self.pre_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_cross_attn_layernorm = T5GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, + past_key_values: Optional[EncoderDecoderCache] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, @@ -421,7 +424,7 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): position_embeddings=position_embeddings, attention_mask=attention_mask, position_ids=position_ids, - past_key_value=past_key_value.self_attention_cache if past_key_value is not None else None, + past_key_values=past_key_values.self_attention_cache if past_key_values is not None else None, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -435,7 +438,7 @@ class T5GemmaDecoderLayer(T5GemmaEncoderLayer): hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, **kwargs, ) @@ -577,6 +580,9 @@ class T5GemmaEncoder(T5GemmaPreTrainedModel): if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + # As we want to pass `past_key_values=None` explicitly everwhere, we need to pop them from kwargs if present + kwargs.pop("past_key_values", None) + if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) diff --git a/src/transformers/models/tapas/modeling_tapas.py b/src/transformers/models/tapas/modeling_tapas.py index b8a681a0aa..6704c77b3b 100644 --- a/src/transformers/models/tapas/modeling_tapas.py +++ b/src/transformers/models/tapas/modeling_tapas.py @@ -32,6 +32,7 @@ from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, Mas from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import ModelOutput, auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_tapas import TapasConfig @@ -300,13 +301,14 @@ class TapasSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -318,19 +320,19 @@ class TapasSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -346,7 +348,7 @@ class TapasSelfAttention(nn.Module): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -354,7 +356,7 @@ class TapasSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) @@ -382,7 +384,7 @@ class TapasSelfAttention(nn.Module): outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) if self.is_decoder: - outputs = outputs + (past_key_value,) + outputs = outputs + (past_key_values,) return outputs @@ -434,7 +436,7 @@ class TapasAttention(nn.Module): attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -443,7 +445,7 @@ class TapasAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -506,7 +508,7 @@ class TapasLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -515,7 +517,7 @@ class TapasLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -533,7 +535,7 @@ class TapasLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -597,7 +599,7 @@ class TapasEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 84d5de004c..20e6b11b68 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -42,6 +42,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -351,11 +352,12 @@ class TimeSeriesTransformerAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -380,19 +382,19 @@ class TimeSeriesTransformerAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -402,7 +404,7 @@ class TimeSeriesTransformerAttention(nn.Module): key_states = key_states.view(*kv_input_shape).transpose(1, 2) value_states = value_states.view(*kv_input_shape).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -410,7 +412,7 @@ class TimeSeriesTransformerAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -540,6 +542,7 @@ class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -548,7 +551,7 @@ class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -566,7 +569,7 @@ class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -579,7 +582,7 @@ class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -599,7 +602,7 @@ class TimeSeriesTransformerDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -1061,7 +1064,7 @@ class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index 0ea823b4ca..e9a12069a0 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -32,6 +32,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_trocr import TrOCRConfig @@ -178,11 +179,12 @@ class TrOCRAttention(nn.Module): self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: Optional[bool] = False, @@ -198,19 +200,19 @@ class TrOCRAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -220,7 +222,7 @@ class TrOCRAttention(nn.Module): key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -228,7 +230,7 @@ class TrOCRAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -330,6 +332,7 @@ class TrOCRDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -338,7 +341,7 @@ class TrOCRDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -356,7 +359,7 @@ class TrOCRDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size *(decoder_attention_heads,)*. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -366,7 +369,7 @@ class TrOCRDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -387,7 +390,7 @@ class TrOCRDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -648,7 +651,7 @@ class TrOCRDecoder(TrOCRPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/udop/modeling_udop.py b/src/transformers/models/udop/modeling_udop.py index a84d15051b..5a807e283e 100644 --- a/src/transformers/models/udop/modeling_udop.py +++ b/src/transformers/models/udop/modeling_udop.py @@ -47,6 +47,7 @@ from ...utils import ( is_torch_flex_attn_available, is_torchdynamo_compiling, ) +from ...utils.deprecation import deprecate_kwarg if is_torch_flex_attn_available(): @@ -562,13 +563,14 @@ class UdopAttention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, mask=None, key_value_states=None, position_bias=None, - past_key_value=None, + past_key_values=None, layer_head_mask=None, query_length=None, use_cache=False, @@ -589,18 +591,18 @@ class UdopAttention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -610,7 +612,7 @@ class UdopAttention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -618,7 +620,7 @@ class UdopAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) @@ -683,13 +685,14 @@ class UdopLayerSelfAttention(nn.Module): self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -700,7 +703,7 @@ class UdopLayerSelfAttention(nn.Module): mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -718,6 +721,7 @@ class UdopLayerCrossAttention(nn.Module): self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -725,7 +729,7 @@ class UdopLayerCrossAttention(nn.Module): attention_mask=None, position_bias=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, query_length=None, output_attentions=False, @@ -738,7 +742,7 @@ class UdopLayerCrossAttention(nn.Module): key_value_states=key_value_states, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, @@ -765,6 +769,7 @@ class UdopBlock(GradientCheckpointingLayer): self.layer.append(UdopLayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -775,7 +780,7 @@ class UdopBlock(GradientCheckpointingLayer): encoder_decoder_position_bias=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, return_dict=True, @@ -786,7 +791,7 @@ class UdopBlock(GradientCheckpointingLayer): attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, @@ -811,7 +816,7 @@ class UdopBlock(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, @@ -1293,7 +1298,7 @@ class UdopStack(UdopPreTrainedModel): encoder_extended_attention_mask, encoder_decoder_position_bias, # as a positional argument for gradient checkpointing layer_head_mask=head_mask[i], - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/umt5/modeling_umt5.py b/src/transformers/models/umt5/modeling_umt5.py index e171be26fe..aa1470e006 100644 --- a/src/transformers/models/umt5/modeling_umt5.py +++ b/src/transformers/models/umt5/modeling_umt5.py @@ -46,6 +46,7 @@ from ...utils import ( is_torchdynamo_compiling, logging, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_umt5 import UMT5Config @@ -257,11 +258,12 @@ class UMT5Attention(nn.Module): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None, - past_key_value: Optional[tuple[torch.Tensor]] = None, + past_key_values: Optional[tuple[torch.Tensor]] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cache_position: Optional[torch.Tensor] = None, @@ -275,18 +277,18 @@ class UMT5Attention(nn.Module): query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -296,7 +298,7 @@ class UMT5Attention(nn.Module): key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -304,13 +306,13 @@ class UMT5Attention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) - real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length + real_seq_length = seq_length + past_key_values.get_seq_length() if past_key_values is not None else seq_length key_length = key_states.shape[-2] if not self.has_relative_attention_bias: position_bias = torch.zeros( @@ -359,12 +361,13 @@ class UMT5LayerSelfAttention(nn.Module): self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) @@ -372,7 +375,7 @@ class UMT5LayerSelfAttention(nn.Module): normed_hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) @@ -387,13 +390,14 @@ class UMT5LayerCrossAttention(nn.Module): self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, encoder_hidden_states=None, attention_mask=None, layer_head_mask=None, - past_key_value=None, + past_key_values=None, cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) @@ -402,7 +406,7 @@ class UMT5LayerCrossAttention(nn.Module): encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) @@ -421,6 +425,7 @@ class UMT5Block(GradientCheckpointingLayer): self.layer.append(UMT5LayerFF(config)) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -429,7 +434,7 @@ class UMT5Block(GradientCheckpointingLayer): encoder_attention_mask=None, layer_head_mask=None, cross_attn_layer_head_mask=None, - past_key_value=None, + past_key_values=None, use_cache=False, output_attentions=False, cache_position=None, @@ -438,7 +443,7 @@ class UMT5Block(GradientCheckpointingLayer): hidden_states, attention_mask=attention_mask, layer_head_mask=layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) @@ -457,7 +462,7 @@ class UMT5Block(GradientCheckpointingLayer): encoder_hidden_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) # clamp inf values to enable fp16 training @@ -762,7 +767,7 @@ class UMT5Stack(UMT5PreTrainedModel): encoder_attention_mask=encoder_extended_attention_mask, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, use_cache=use_cache, output_attentions=output_attentions, cache_position=cache_position, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index ad8ac7cee3..bfeb0eb7b9 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -42,6 +42,7 @@ from ...modeling_outputs import ( from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_whisper import WhisperConfig from .generation_whisper import WhisperGenerationMixin @@ -284,11 +285,12 @@ class WhisperAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -317,30 +319,30 @@ class WhisperAttention(nn.Module): query_states = query_states.transpose(1, 2).contiguous() # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache` - if past_key_value is not None and isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value.is_updated[self.layer_idx] = True - past_key_value = past_key_value.cross_attention_cache + past_key_values.is_updated[self.layer_idx] = True + past_key_values = past_key_values.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + past_key_values = past_key_values.self_attention_cache # use key_value_states if cross attention current_states = key_value_states if key_value_states is not None else hidden_states - if is_cross_attention and past_key_value and is_updated: + if is_cross_attention and past_key_values and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.layers[self.layer_idx].keys - value_states = past_key_value.layers[self.layer_idx].values + key_states = past_key_values.layers[self.layer_idx].keys + value_states = past_key_values.layers[self.layer_idx].values else: key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) key_states = key_states.transpose(1, 2).contiguous() value_states = value_states.transpose(1, 2).contiguous() - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None - key_states, value_states = past_key_value.update( + key_states, value_states = past_key_values.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) @@ -463,6 +465,7 @@ class WhisperDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -471,7 +474,7 @@ class WhisperDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[EncoderDecoderCache] = None, + past_key_values: Optional[EncoderDecoderCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, @@ -489,7 +492,7 @@ class WhisperDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -500,7 +503,7 @@ class WhisperDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -519,7 +522,7 @@ class WhisperDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -939,7 +942,7 @@ class WhisperDecoder(WhisperPreTrainedModel): encoder_hidden_states=encoder_hidden_states, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values if use_cache else None, + past_key_values=past_key_values if use_cache else None, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -1136,7 +1139,7 @@ class WhisperModel(WhisperPreTrainedModel): attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + # decoder outputs consists of (dec_features, past_key_values, dec_hidden, dec_attn) decoder_outputs = self.decoder( input_ids=decoder_input_ids, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 7f54c3fac8..d9df105204 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -29,6 +29,7 @@ from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xglm import XGLMConfig @@ -133,11 +134,12 @@ class XGLMAttention(nn.Module): self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -155,19 +157,19 @@ class XGLMAttention(nn.Module): # get query proj query_states = self.q_proj(hidden_states) * self.scaling - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = key_value_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_states = curr_past_key_value.layers[self.layer_idx].keys value_states = curr_past_key_value.layers[self.layer_idx].values @@ -177,7 +179,7 @@ class XGLMAttention(nn.Module): key_states = key_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, src_len, -1, self.head_dim).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_states, value_states = curr_past_key_value.update( @@ -185,7 +187,7 @@ class XGLMAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim).transpose(1, 2) @@ -291,6 +293,7 @@ class XGLMDecoderLayer(GradientCheckpointingLayer): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, @@ -300,7 +303,7 @@ class XGLMDecoderLayer(GradientCheckpointingLayer): encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.Tensor] = None, @@ -318,7 +321,7 @@ class XGLMDecoderLayer(GradientCheckpointingLayer): `(encoder_attention_heads,)`. cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + past_key_values (`Tuple(torch.FloatTensor)`): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -329,7 +332,7 @@ class XGLMDecoderLayer(GradientCheckpointingLayer): # Self Attention hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -349,7 +352,7 @@ class XGLMDecoderLayer(GradientCheckpointingLayer): key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -556,7 +559,7 @@ class XGLMModel(XGLMPreTrainedModel): encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py index d1c70b6028..fdff73ff77 100644 --- a/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py @@ -42,6 +42,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta import XLMRobertaConfig @@ -167,13 +168,14 @@ class XLMRobertaSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -184,19 +186,19 @@ class XLMRobertaSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -210,7 +212,7 @@ class XLMRobertaSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -218,14 +220,14 @@ class XLMRobertaSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -278,13 +280,14 @@ class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -302,7 +305,7 @@ class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -315,19 +318,19 @@ class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -343,7 +346,7 @@ class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -351,7 +354,7 @@ class XLMRobertaSdpaSelfAttention(XLMRobertaSelfAttention): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -433,13 +436,14 @@ class XLMRobertaAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -448,7 +452,7 @@ class XLMRobertaAttention(nn.Module): attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -504,6 +508,7 @@ class XLMRobertaLayer(GradientCheckpointingLayer): self.intermediate = XLMRobertaIntermediate(config) self.output = XLMRobertaOutput(config) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -511,7 +516,7 @@ class XLMRobertaLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -520,7 +525,7 @@ class XLMRobertaLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -538,7 +543,7 @@ class XLMRobertaLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) @@ -613,7 +618,7 @@ class XLMRobertaEncoder(nn.Module): layer_head_mask, encoder_hidden_states, # as a positional argument for gradient checkpointing encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py index 62793c38bc..c666aef841 100644 --- a/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +++ b/src/transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py @@ -41,6 +41,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, get_torch_version, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xlm_roberta_xl import XLMRobertaXLConfig @@ -164,13 +165,14 @@ class XLMRobertaXLSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -181,19 +183,19 @@ class XLMRobertaXLSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -207,7 +209,7 @@ class XLMRobertaXLSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -215,14 +217,14 @@ class XLMRobertaXLSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -275,13 +277,14 @@ class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): self.require_contiguous_qkv = version.parse(get_torch_version()) < version.parse("2.2.0") # Adapted from XLMRobertaXLSelfAttention + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -299,7 +302,7 @@ class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -312,19 +315,19 @@ class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): is_cross_attention = encoder_hidden_states is not None current_states = encoder_hidden_states if is_cross_attention else hidden_states - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -340,7 +343,7 @@ class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): .transpose(1, 2) ) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -348,7 +351,7 @@ class XLMRobertaXLSdpaSelfAttention(XLMRobertaXLSelfAttention): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. @@ -428,13 +431,14 @@ class XLMRobertaXLAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -444,7 +448,7 @@ class XLMRobertaXLAttention(nn.Module): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -498,6 +502,7 @@ class XLMRobertaXLLayer(GradientCheckpointingLayer): self.output = XLMRobertaXLOutput(config) self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states, @@ -505,7 +510,7 @@ class XLMRobertaXLLayer(GradientCheckpointingLayer): head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None, - past_key_value=None, + past_key_values=None, output_attentions=False, cache_position=None, ): @@ -514,7 +519,7 @@ class XLMRobertaXLLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -532,7 +537,7 @@ class XLMRobertaXLLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/xmod/modeling_xmod.py b/src/transformers/models/xmod/modeling_xmod.py index 6b4ac64f4e..80e16dc966 100644 --- a/src/transformers/models/xmod/modeling_xmod.py +++ b/src/transformers/models/xmod/modeling_xmod.py @@ -39,6 +39,7 @@ from ...modeling_outputs import ( from ...modeling_utils import PreTrainedModel from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_xmod import XmodConfig @@ -164,13 +165,14 @@ class XmodSelfAttention(nn.Module): self.is_decoder = config.is_decoder self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -181,19 +183,19 @@ class XmodSelfAttention(nn.Module): ) is_cross_attention = encoder_hidden_states is not None - if past_key_value is not None: - if isinstance(past_key_value, EncoderDecoderCache): - is_updated = past_key_value.is_updated.get(self.layer_idx) + if past_key_values is not None: + if isinstance(past_key_values, EncoderDecoderCache): + is_updated = past_key_values.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_layer from cache - curr_past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_values.cross_attention_cache else: - curr_past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_values.self_attention_cache else: - curr_past_key_value = past_key_value + curr_past_key_value = past_key_values current_states = encoder_hidden_states if is_cross_attention else hidden_states - if is_cross_attention and past_key_value is not None and is_updated: + if is_cross_attention and past_key_values is not None and is_updated: # reuse k,v, cross_attentions key_layer = curr_past_key_value.layers[self.layer_idx].keys value_layer = curr_past_key_value.layers[self.layer_idx].values @@ -207,7 +209,7 @@ class XmodSelfAttention(nn.Module): batch_size, -1, self.num_attention_heads, self.attention_head_size ).transpose(1, 2) - if past_key_value is not None: + if past_key_values is not None: # save all key/value_layer to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None key_layer, value_layer = curr_past_key_value.update( @@ -215,14 +217,14 @@ class XmodSelfAttention(nn.Module): ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: - past_key_value.is_updated[self.layer_idx] = True + past_key_values.is_updated[self.layer_idx] = True # Take the dot product between "query" and "key" to get the raw attention scores. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query": query_length, key_length = query_layer.shape[2], key_layer.shape[2] - if past_key_value is not None: + if past_key_values is not None: position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view( -1, 1 ) @@ -309,13 +311,14 @@ class XmodAttention(nn.Module): self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads self.pruned_heads = self.pruned_heads.union(heads) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -327,7 +330,7 @@ class XmodAttention(nn.Module): attention_mask, head_mask, encoder_hidden_states, - past_key_value, + past_key_values, output_attentions, cache_position, ) @@ -438,6 +441,7 @@ class XmodLayer(GradientCheckpointingLayer): self.output = XmodOutput(config) self.pre_norm = config.pre_norm + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -446,7 +450,7 @@ class XmodLayer(GradientCheckpointingLayer): head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.FloatTensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - past_key_value: Optional[Cache] = None, + past_key_values: Optional[Cache] = None, output_attentions: Optional[bool] = False, cache_position: Optional[torch.Tensor] = None, ) -> tuple[torch.Tensor]: @@ -455,7 +459,7 @@ class XmodLayer(GradientCheckpointingLayer): attention_mask=attention_mask, head_mask=head_mask, output_attentions=output_attentions, - past_key_value=past_key_value, + past_key_values=past_key_values, cache_position=cache_position, ) attention_output = self_attention_outputs[0] @@ -473,7 +477,7 @@ class XmodLayer(GradientCheckpointingLayer): attention_mask=encoder_attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, cache_position=cache_position, ) diff --git a/src/transformers/models/zamba/modeling_zamba.py b/src/transformers/models/zamba/modeling_zamba.py index 16290ea4e1..e04af25feb 100644 --- a/src/transformers/models/zamba/modeling_zamba.py +++ b/src/transformers/models/zamba/modeling_zamba.py @@ -36,6 +36,7 @@ from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba import ZambaConfig @@ -256,12 +257,13 @@ class ZambaAttention(nn.Module): self.v_proj = nn.Linear(config.attention_hidden_size, config.num_key_value_heads * self.head_dim, bias=False) self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor], - past_key_value: Optional[ZambaHybridDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: input_shape = hidden_states.shape[:-1] @@ -271,8 +273,8 @@ class ZambaAttention(nn.Module): key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -596,13 +598,14 @@ class ZambaAttentionDecoderLayer(nn.Module): self.input_layernorm = ZambaRMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[ZambaHybridDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, **kwargs: Unpack[FlashAttentionKwargs], @@ -617,7 +620,7 @@ class ZambaAttentionDecoderLayer(nn.Module): layer_idx (`int`): layer_idx in the forward pass. Used to distinguish Zamba's tied transformer layers. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -633,7 +636,7 @@ class ZambaAttentionDecoderLayer(nn.Module): hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, **kwargs, @@ -657,6 +660,7 @@ class ZambaMambaDecoderLayer(nn.Module): self.input_layernorm = ZambaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -664,7 +668,7 @@ class ZambaMambaDecoderLayer(nn.Module): layer_idx: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[ZambaHybridDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -676,7 +680,7 @@ class ZambaMambaDecoderLayer(nn.Module): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -698,7 +702,7 @@ class ZambaMambaDecoderLayer(nn.Module): hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, attention_mask=attention_mask, ) @@ -713,7 +717,7 @@ class ZambaMambaDecoderLayer(nn.Module): outputs += (self_attn_weights,) if use_cache: - outputs += (past_key_value,) + outputs += (past_key_values,) return outputs @@ -725,6 +729,7 @@ class ZambaHybridLayer(nn.Module): self.linear = linear self.mamba_decoder = mamba + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -732,7 +737,7 @@ class ZambaHybridLayer(nn.Module): layer_idx: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[ZambaHybridDynamicCache] = None, + past_key_values: Optional[ZambaHybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -745,7 +750,7 @@ class ZambaHybridLayer(nn.Module): layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`ZambaHybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -761,7 +766,7 @@ class ZambaHybridLayer(nn.Module): original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -778,7 +783,7 @@ class ZambaHybridLayer(nn.Module): hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -972,7 +977,7 @@ class ZambaModel(ZambaPreTrainedModel): layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, diff --git a/src/transformers/models/zamba2/modeling_zamba2.py b/src/transformers/models/zamba2/modeling_zamba2.py index 1ce66bf962..2f1e1e0bc6 100644 --- a/src/transformers/models/zamba2/modeling_zamba2.py +++ b/src/transformers/models/zamba2/modeling_zamba2.py @@ -38,6 +38,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import is_causal_conv1d_available, is_mamba_ssm_available from .configuration_zamba2 import Zamba2Config @@ -395,12 +396,13 @@ class Zamba2Attention(nn.Module): self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -424,8 +426,8 @@ class Zamba2Attention(nn.Module): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -986,13 +988,14 @@ class Zamba2AttentionDecoderLayer(nn.Module): self.input_layernorm = Zamba2RMSNorm(config.attention_hidden_size, eps=config.rms_norm_eps) self.pre_ff_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, position_embeddings: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -1006,7 +1009,7 @@ class Zamba2AttentionDecoderLayer(nn.Module): (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1023,7 +1026,7 @@ class Zamba2AttentionDecoderLayer(nn.Module): hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=position_embeddings, **kwargs, @@ -1047,6 +1050,7 @@ class Zamba2MambaDecoderLayer(nn.Module): self.input_layernorm = Zamba2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.layer_idx = layer_idx + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1054,7 +1058,7 @@ class Zamba2MambaDecoderLayer(nn.Module): layer_idx: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -1066,7 +1070,7 @@ class Zamba2MambaDecoderLayer(nn.Module): hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1088,7 +1092,7 @@ class Zamba2MambaDecoderLayer(nn.Module): hidden_states = self.mamba( hidden_states=hidden_states, - cache_params=past_key_value, + cache_params=past_key_values, attention_mask=attention_mask, ) @@ -1103,7 +1107,7 @@ class Zamba2MambaDecoderLayer(nn.Module): outputs += (self_attn_weights,) if use_cache: - outputs += (past_key_value,) + outputs += (past_key_values,) return outputs @@ -1117,6 +1121,7 @@ class Zamba2HybridLayer(nn.Module): self.mamba_decoder = mamba self.shared_transformer = shared_transformer + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -1124,7 +1129,7 @@ class Zamba2HybridLayer(nn.Module): layer_idx: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_embeddings: Optional[torch.LongTensor] = None, @@ -1137,7 +1142,7 @@ class Zamba2HybridLayer(nn.Module): layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -1154,7 +1159,7 @@ class Zamba2HybridLayer(nn.Module): original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=position_embeddings, ) @@ -1170,7 +1175,7 @@ class Zamba2HybridLayer(nn.Module): hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -1348,7 +1353,7 @@ class Zamba2Model(Zamba2PreTrainedModel): layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, diff --git a/src/transformers/models/zamba2/modular_zamba2.py b/src/transformers/models/zamba2/modular_zamba2.py index 44f28f9dc3..5eb00899c2 100644 --- a/src/transformers/models/zamba2/modular_zamba2.py +++ b/src/transformers/models/zamba2/modular_zamba2.py @@ -30,6 +30,7 @@ from ...processing_utils import Unpack from ...utils import ( logging, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.import_utils import ( is_causal_conv1d_available, is_mamba_ssm_available, @@ -227,12 +228,13 @@ class Zamba2Attention(ZambaAttention): self.layer_dic = {value: index for index, value in enumerate(self.layer_block_map)} + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, **kwargs: Unpack[FlashAttentionKwargs], ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]: @@ -256,8 +258,8 @@ class Zamba2Attention(ZambaAttention): cos, sin = position_embeddings query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - if past_key_value is not None: - key_states, value_states = past_key_value.update(key_states, value_states, layer_idx) + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, layer_idx) attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": @@ -759,13 +761,14 @@ class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): self.self_attn = Zamba2Attention(config, layer_idx=-1, num_fwd_mem_blocks=num_gs, block_id=block_id) self.feed_forward = Zamba2MLP(config, num_fwd_mem_blocks=num_gs, block_id=block_id) + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, original_hidden_states: torch.Tensor, layer_idx: int, attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, position_embeddings: Optional[torch.LongTensor] = None, **kwargs: Unpack[FlashAttentionKwargs], @@ -779,7 +782,7 @@ class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): (see fig. 2 in https://huggingface.co/papers/2405.16712). attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -796,7 +799,7 @@ class Zamba2AttentionDecoderLayer(ZambaAttentionDecoderLayer): hidden_states=hidden_states, layer_idx=layer_idx, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=position_embeddings, **kwargs, @@ -828,6 +831,7 @@ class Zamba2HybridLayer(ZambaHybridLayer): del self.shared_transf self.shared_transformer = shared_transformer + @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58") def forward( self, hidden_states: torch.Tensor, @@ -835,7 +839,7 @@ class Zamba2HybridLayer(ZambaHybridLayer): layer_idx: Optional[int] = None, attention_mask: Optional[torch.Tensor] = None, causal_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Zamba2HybridDynamicCache] = None, + past_key_values: Optional[Zamba2HybridDynamicCache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, position_embeddings: Optional[torch.LongTensor] = None, @@ -848,7 +852,7 @@ class Zamba2HybridLayer(ZambaHybridLayer): layer_idx (`int`): layer number. attention_mask (`torch.FloatTensor`, *optional*): attention mask of size `(batch, sequence_length)` where padding elements are indicated by 0. - past_key_value (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states + past_key_values (`Zamba2HybridDynamicCache`, *optional*): cached past key and value projection states output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -865,7 +869,7 @@ class Zamba2HybridLayer(ZambaHybridLayer): original_hidden_states=original_hidden_states, layer_idx=layer_idx, attention_mask=causal_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, position_embeddings=position_embeddings, ) @@ -881,7 +885,7 @@ class Zamba2HybridLayer(ZambaHybridLayer): hidden_states, transformer_hidden_states=transformer_hidden_states, attention_mask=attention_mask, - past_key_value=past_key_value, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, @@ -1105,7 +1109,7 @@ class Zamba2Model(ZambaModel, Zamba2PreTrainedModel): layer_idx=layer_idx, attention_mask=attention_mask, causal_mask=causal_mask, - past_key_value=past_key_values, + past_key_values=past_key_values, output_attentions=output_attentions, use_cache=use_cache, position_embeddings=position_embeddings, diff --git a/src/transformers/utils/auto_docstring.py b/src/transformers/utils/auto_docstring.py index 8a13486b1d..ba03cf9cfe 100644 --- a/src/transformers/utils/auto_docstring.py +++ b/src/transformers/utils/auto_docstring.py @@ -365,13 +365,6 @@ class ModelArgs: "shape": None, } - past_key_value = { - "description": """ - deprecated in favor of `past_key_values` - """, - "shape": None, - } - inputs_embeds = { "description": """ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This diff --git a/tests/utils/test_auto_docstring.py b/tests/utils/test_auto_docstring.py index 11ac05f3b1..1874631f5e 100644 --- a/tests/utils/test_auto_docstring.py +++ b/tests/utils/test_auto_docstring.py @@ -20,7 +20,7 @@ LLAMA_CLM_FORWARD = """ The [`LlamaForCausalLM`] forward method, override LLAMA_MODEL_DOCSTRING = """ The [`LlamaModel`] forward method, overrides the `__call__` special method.\n\n \n\n Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]\n instance afterwards instead of this since the former takes care of running the pre and post processing steps while\n the latter silently ignores them.\n\n \n\n Args:\n input_ids (`Optional[torch.LongTensor]`)of shape `(batch_size, sequence_length)`):\n Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.\n\n Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n [`PreTrainedTokenizer.__call__`] for details.\n\n [What are input IDs?](../glossary#input-ids)\n attention_mask (`Optional[torch.Tensor]`) of shape `(batch_size, sequence_length)`:\n Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n - 1 for tokens that are **not masked**,\n - 0 for tokens that are **masked**.\n\n [What are attention masks?](../glossary#attention-mask)\n position_ids (`Optional[torch.LongTensor]`):\n Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.\n\n [What are position IDs?](../glossary#position-ids)\n past_key_values (`Optional[~cache_utils.Cache]`):\n Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n Two formats are allowed:\n - a `~cache_utils.Cache` instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);\n - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n cache format.\n\n The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n legacy cache format will be returned.\n\n If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't\n have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n of shape `(batch_size, sequence_length)`.\n inputs_embeds (`Optional[torch.FloatTensor]`):\n Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n model's internal embedding lookup matrix.\n use_cache (`Optional[bool]`):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n `past_key_values`).\n output_attentions (`Optional[bool]`):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n tensors for more detail.\n output_hidden_states (`Optional[bool]`):\n Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n more detail.\n cache_position (`Optional[torch.LongTensor]`):\n Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n the complete sequence length.\n\n Returns:\n [`transformers.modeling_outputs.BaseModelOutputWithPast`] or `tuple(torch.FloatTensor)`: A [`transformers.modeling_outputs.BaseModelOutputWithPast`] or a tuple of\n `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various\n elements depending on the configuration ([`LlamaConfig`]) and inputs.\n\n - **last_hidden_state** (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`) -- Sequence of hidden-states at the output of the last layer of the model.\n\n If `past_key_values` is used only the last hidden-state of the sequences of shape `(batch_size, 1,\n hidden_size)` is output.\n - **past_key_values** (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and optionally if\n `config.is_encoder_decoder=True` 2 additional tensors of shape `(batch_size, num_heads,\n encoder_sequence_length, embed_size_per_head)`.\n\n Contains pre-computed hidden-states (key and values in the self-attention blocks and optionally if\n `config.is_encoder_decoder=True` in the cross-attention blocks) that can be used (see `past_key_values`\n input) to speed up sequential decoding.\n - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +\n one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.\n\n Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.\n - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n sequence_length)`.\n\n Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n heads.\n""" -LLAMA_DECODER = """ The [`LlamaDecoderLayer`] forward method, overrides the `__call__` special method.\n\n \n\n Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]\n instance afterwards instead of this since the former takes care of running the pre and post processing steps while\n the latter silently ignores them.\n\n \n\n Args:\n hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim) attention_mask (`Optional[torch.Tensor]`) of shape `(batch_size, sequence_length)`:\n Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n - 1 for tokens that are **not masked**,\n - 0 for tokens that are **masked**.\n\n [What are attention masks?](../glossary#attention-mask)\n position_ids (`Optional[torch.LongTensor]`):\n Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.\n\n [What are position IDs?](../glossary#position-ids)\n past_key_value (`Optional[~cache_utils.Cache]`):deprecated in favor of `past_key_values` output_attentions (`Optional[bool]`, defaults to `False`):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n tensors for more detail.\n use_cache (`Optional[bool]`, defaults to `False`):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n `past_key_values`).\n cache_position (`Optional[torch.LongTensor]`):\n Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n the complete sequence length.\n position_embeddings (`Optional[Tuple[torch.Tensor, torch.Tensor]]`):\n Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n with `head_dim` being the embedding dimension of each attention head.\n\n Returns:\n `Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]`""" +LLAMA_DECODER = """ The [`LlamaDecoderLayer`] forward method, overrides the `__call__` special method.\n\n \n\n Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]\n instance afterwards instead of this since the former takes care of running the pre and post processing steps while\n the latter silently ignores them.\n\n \n\n Args:\n hidden_states (`torch.Tensor`): input to the layer of shape `(batch, seq_len, embed_dim) attention_mask (`Optional[torch.Tensor]`) of shape `(batch_size, sequence_length)`:\n Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n - 1 for tokens that are **not masked**,\n - 0 for tokens that are **masked**.\n\n [What are attention masks?](../glossary#attention-mask)\n position_ids (`Optional[torch.LongTensor]`):\n Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.\n\n [What are position IDs?](../glossary#position-ids)\n past_key_values (`Optional[~cache_utils.Cache]`):deprecated in favor of `past_key_values` output_attentions (`Optional[bool]`, defaults to `False`):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n tensors for more detail.\n use_cache (`Optional[bool]`, defaults to `False`):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n `past_key_values`).\n cache_position (`Optional[torch.LongTensor]`):\n Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,\n this tensor is not affected by padding. It is used to update the cache in the correct position and to infer\n the complete sequence length.\n position_embeddings (`Optional[Tuple[torch.Tensor, torch.Tensor]]`):\n Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,\n with `head_dim` being the embedding dimension of each attention head.\n\n Returns:\n `Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]`""" LLAMA_FOR_SEQUENCE_CLASSIFICATION_DOC = """ The [`LlamaForSequenceClassification`] forward method, overrides the `__call__` special method.\n\n \n\n Although the recipe for forward pass needs to be defined within this function, one should call the [`Module`]\n instance afterwards instead of this since the former takes care of running the pre and post processing steps while\n the latter silently ignores them.\n\n \n\n Args:\n input_ids (`Optional[torch.LongTensor]`)of shape `(batch_size, sequence_length)`):\n Indices of input sequence tokens in the vocabulary. Padding will be ignored by default.\n\n Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and\n [`PreTrainedTokenizer.__call__`] for details.\n\n [What are input IDs?](../glossary#input-ids)\n attention_mask (`Optional[torch.Tensor]`) of shape `(batch_size, sequence_length)`:\n Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:\n\n - 1 for tokens that are **not masked**,\n - 0 for tokens that are **masked**.\n\n [What are attention masks?](../glossary#attention-mask)\n position_ids (`Optional[torch.LongTensor]`):\n Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, config.n_positions - 1]`.\n\n [What are position IDs?](../glossary#position-ids)\n past_key_values (`Optional[~cache_utils.Cache]`):\n Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention\n blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`\n returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.\n\n Two formats are allowed:\n - a `~cache_utils.Cache` instance, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);\n - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of\n shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy\n cache format.\n\n The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the\n legacy cache format will be returned.\n\n If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don\'t\n have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`\n of shape `(batch_size, sequence_length)`.\n inputs_embeds (`Optional[torch.FloatTensor]`):\n Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This\n is useful if you want more control over how to convert `input_ids` indices into associated vectors than the\n model\'s internal embedding lookup matrix.\n labels (`Optional[torch.LongTensor]`):\n Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,\n config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If\n `config.num_labels > 1` a classification loss is computed (Cross-Entropy).\n use_cache (`Optional[bool]`):\n If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see\n `past_key_values`).\n output_attentions (`Optional[bool]`):\n Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned\n tensors for more detail.\n output_hidden_states (`Optional[bool]`):\n Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for\n more detail.\n\n Returns:\n [`transformers.modeling_outputs.SequenceClassifierOutputWithPast`] or `tuple(torch.FloatTensor)`: A [`transformers.modeling_outputs.SequenceClassifierOutputWithPast`] or a tuple of\n `torch.FloatTensor` (if `return_dict=False` is passed or when `config.return_dict=False`) comprising various\n elements depending on the configuration ([`LlamaConfig`]) and inputs.\n\n - **loss** (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided) -- Classification (or regression if config.num_labels==1) loss.\n - **logits** (`torch.FloatTensor` of shape `(batch_size, config.num_labels)`) -- Classification (or regression if config.num_labels==1) scores (before SoftMax).\n - **past_key_values** (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`) -- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape\n `(batch_size, num_heads, sequence_length, embed_size_per_head)`)\n\n Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see\n `past_key_values` input) to speed up sequential decoding.\n - **hidden_states** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`) -- Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, +\n one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`.\n\n Hidden-states of the model at the output of each layer plus the optional initial embedding outputs.\n - **attentions** (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`) -- Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,\n sequence_length)`.\n\n Attentions weights after the attention softmax, used to compute the weighted average in the self-attention\n heads.\n\n Example of single-label classification:\n\n ```python\n >>> import torch\n >>> from transformers import AutoTokenizer, LlamaForSequenceClassification\n\n >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")\n >>> model = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf")\n\n >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")\n\n >>> with torch.no_grad():\n ... logits = model(**inputs).logits\n\n >>> predicted_class_id = logits.argmax().item()\n >>> model.config.id2label[predicted_class_id]\n ...\n\n >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`\n >>> num_labels = len(model.config.id2label)\n >>> model = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf", num_labels=num_labels)\n\n >>> labels = torch.tensor([1])\n >>> loss = model(**inputs, labels=labels).loss\n >>> round(loss.item(), 2)\n ...\n ```\n\n Example of multi-label classification:\n\n ```python\n >>> import torch\n >>> from transformers import AutoTokenizer, LlamaForSequenceClassification\n\n >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")\n >>> model = LlamaForSequenceClassification.from_pretrained("meta-llama/Llama-2-7b-hf", problem_type="multi_label_classification")\n\n >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")\n\n >>> with torch.no_grad():\n ... logits = model(**inputs).logits\n\n >>> predicted_class_ids = torch.arange(0, logits.shape[-1])[torch.sigmoid(logits).squeeze(dim=0) > 0.5]\n\n >>> # To train a model on `num_labels` classes, you can pass `num_labels=num_labels` to `.from_pretrained(...)`\n >>> num_labels = len(model.config.id2label)\n >>> model = LlamaForSequenceClassification.from_pretrained(\n ... "meta-llama/Llama-2-7b-hf", num_labels=num_labels, problem_type="multi_label_classification"\n ... )\n\n >>> labels = torch.sum(\n ... torch.nn.functional.one_hot(predicted_class_ids[None, :].clone(), num_classes=num_labels), dim=1\n ... ).to(torch.float)\n >>> loss = model(**inputs, labels=labels).loss\n ```\n"""