Harmonize `past_key_value` to `past_key_values` everywhere (#39956)

* all modulars and llama

* apply modular

* bert and gpt2 copies

* fix imports

* do it everywhere

* fix import

* finalize it

* fix

* oops, set it in modular

* style

* fix

* Add 1 version to deprecation cycle

* Update modeling_layers.py
This commit is contained in:
Cyril Vallez
2025-08-08 11:52:57 +02:00
committed by GitHub
parent 2469cce621
commit 5c3fb7f731
211 changed files with 3159 additions and 2733 deletions

View File

@@ -218,7 +218,7 @@ class Olmo2Attention(OlmoAttention):
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
- past_key_value: Optional[Cache] = None,
+ past_key_values: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs,
) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
@@ -236,10 +236,10 @@ class Olmo2Attention(OlmoAttention):
cos, sin = position_embeddings
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
- if past_key_value is not None:
+ if past_key_values is not None:
# sin and cos are specific to RoPE models; cache_position needed for the static cache
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
attention_interface: Callable = eager_attention_forward
if self.config._attn_implementation != "eager":
@@ -290,7 +290,7 @@ class Olmo2DecoderLayer(OlmoDecoderLayer):
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
- past_key_value: Optional[Cache] = None,
+ past_key_values: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
@@ -304,7 +304,7 @@ class Olmo2DecoderLayer(OlmoDecoderLayer):
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
- past_key_value=past_key_value,
+ past_key_values=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,