From 7c62e69326c9e7a78a1339df7d39c9a1c5c43de0 Mon Sep 17 00:00:00 2001 From: Poedator <24738311+poedator@users.noreply.github.com> Date: Thu, 24 Apr 2025 15:46:35 +0300 Subject: [PATCH] `GPT2Model` StaticCache support (#35761) * initial GPT2 changes * causal_mask support * return_legacy_cache * cleanup * fix1 * outputs shape fixes * gpt2 return fix * pkv, attn fixes * fix dual_head * is_causal arg fix * decision transformer updated * style fix * batch_size from inputs_embeds * DecisionTransformerModel fixes * cross-attn support + cache warning * x-attn @decision * EDCache proper init * simplified logic in `if use_cache:` for GPT2Model * @deprecate_kwarg for DecisionTr attn fwd * @deprecate_kwarg in gpt2 * deprecation version updated to 4.51 * kwargs in gradient_checkpointing_fn * rename next_cache to past_key_values * attention_mask prep * +cache_position in GPT2DoubleHeadsModel * undo kwargs in gradient checkpointing * moved up `if self.gradient_checkpointing` * consistency in decision_transformer * pastkv, cache_pos in grad_checkpt args * rm _reorder_cache * output_attentions streamlined * decision_transformer consistency * return_legacy_cache improved * ClvpForCausalLM used for legacy cache test now * is_causal fixed * attn_output cleanup * consistency @ decision_transformer * Updated deprecation notice version to 4.52 * upd deprecation * consistent legacy cache code in decision transformers\ * next_cache -> past_kv in decision_tr * cache support flags in decision_transf * rm legacy cache warning * consistency in cache init for decision transf * no Static Cache for Decision Transformer --------- Co-authored-by: Cyril Vallez --- src/transformers/models/clvp/modeling_clvp.py | 1 - .../modeling_decision_transformer.py | 129 ++++--- src/transformers/models/gpt2/modeling_gpt2.py | 356 ++++++++++++------ tests/utils/test_cache_utils.py | 4 +- 4 files changed, 324 insertions(+), 166 deletions(-) diff --git a/src/transformers/models/clvp/modeling_clvp.py b/src/transformers/models/clvp/modeling_clvp.py index bafaa36dd6..a63e0df512 100644 --- a/src/transformers/models/clvp/modeling_clvp.py +++ b/src/transformers/models/clvp/modeling_clvp.py @@ -1589,7 +1589,6 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin): ) @staticmethod - # Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache def _reorder_cache( past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor ) -> Tuple[Tuple[torch.Tensor]]: diff --git a/src/transformers/models/decision_transformer/modeling_decision_transformer.py b/src/transformers/models/decision_transformer/modeling_decision_transformer.py index ab8c190775..54000b8f24 100755 --- a/src/transformers/models/decision_transformer/modeling_decision_transformer.py +++ b/src/transformers/models/decision_transformer/modeling_decision_transformer.py @@ -24,6 +24,7 @@ import torch.utils.checkpoint from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer @@ -34,6 +35,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from .configuration_decision_transformer import DecisionTransformerConfig @@ -125,7 +127,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask if attention_mask is not None: # Apply the attention mask - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -257,19 +260,21 @@ class DecisionTransformerGPT2Attention(nn.Module): return attn_output, attn_weights + @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " @@ -289,17 +294,17 @@ class DecisionTransformerGPT2Attention(nn.Module): key_states = key_states.view(shape_kv).transpose(1, 2) value_states = value_states.view(shape_kv).transpose(1, 2) - if layer_past is not None: - past_key, past_value = layer_past - key_states = torch.cat((past_key, key_states), dim=-2) - value_states = torch.cat((past_value, value_states), dim=-2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + if is_cross_attention: + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs + ) - if use_cache is True: - present = (key_states, value_states) - else: - present = None - - is_cross_attention = encoder_hidden_states is not None is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention using_eager = self.config._attn_implementation == "eager" @@ -338,11 +343,7 @@ class DecisionTransformerGPT2Attention(nn.Module): attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights # Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2 @@ -383,10 +384,12 @@ class DecisionTransformerGPT2Block(nn.Module): self.mlp = DecisionTransformerGPT2MLP(inner_dim, config) + @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, @@ -396,16 +399,15 @@ class DecisionTransformerGPT2Block(nn.Module): ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_output, self_attn_weights = self.attn( hidden_states, - layer_past=layer_past, + past_key_value=past_key_value, + cache_position=cache_position, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -418,18 +420,17 @@ class DecisionTransformerGPT2Block(nn.Module): ) residual = hidden_states hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( + cross_attn_output, cross_attn_weights = self.crossattention( hidden_states, + past_key_value=past_key_value, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) - attn_output = cross_attn_outputs[0] # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + hidden_states = residual + cross_attn_output residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -437,12 +438,13 @@ class DecisionTransformerGPT2Block(nn.Module): # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if encoder_hidden_states is not None: + outputs += (cross_attn_weights,) - return outputs # hidden_states, present, (attentions, cross_attentions) + return outputs class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): @@ -456,6 +458,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel): base_model_prefix = "transformer" is_parallelizable = True supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = False def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -521,6 +525,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -558,14 +563,31 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) + # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder and similar addition in GPT2Model + return_legacy_cache = False + if use_cache: + if past_key_values is None: + return_legacy_cache = True + past_key_values = DynamicCache() + elif not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " + "You should pass an instance of `Cache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + position_ids = cache_position.unsqueeze(0) # Attention mask. if attention_mask is not None: @@ -624,17 +646,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): ) use_cache = False - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + for i, block in enumerate(self.h): # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -648,6 +666,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): block.__call__, hidden_states, None, + None, attention_mask, head_mask[i], encoder_hidden_states, @@ -658,7 +677,8 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): else: outputs = block( hidden_states, - layer_past=layer_past, + past_key_value=past_key_values, + cache_position=cache_position, attention_mask=attention_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, @@ -668,13 +688,11 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -689,16 +707,23 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_values = past_key_values if use_cache else None + if return_legacy_cache: + past_key_values = ( + past_key_values.self_attention_cache.to_legacy_cache() + if self.config.add_cross_attention + else past_key_values.to_legacy_cache() + ) if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py index 717962b337..6aec582b7c 100644 --- a/src/transformers/models/gpt2/modeling_gpt2.py +++ b/src/transformers/models/gpt2/modeling_gpt2.py @@ -27,8 +27,9 @@ from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN, get_activation +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -46,6 +47,7 @@ from ...utils import ( logging, replace_return_docstrings, ) +from ...utils.deprecation import deprecate_kwarg from ...utils.model_parallel_utils import assert_device_map, get_device_map from .configuration_gpt2 import GPT2Config @@ -136,7 +138,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask if attention_mask is not None: # Apply the attention mask - attn_weights = attn_weights + attention_mask + causal_mask = attention_mask[:, :, :, : key.shape[-2]] + attn_weights = attn_weights + causal_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) @@ -267,19 +270,21 @@ class GPT2Attention(nn.Module): return attn_output, attn_weights + @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.FloatTensor] = None, - use_cache: Optional[bool] = False, output_attentions: Optional[bool] = False, **kwargs, ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]: - if encoder_hidden_states is not None: + is_cross_attention = encoder_hidden_states is not None + if is_cross_attention: if not hasattr(self, "q_attn"): raise ValueError( "If class is used as cross attention, the weights `q_attn` have to be defined. " @@ -299,17 +304,17 @@ class GPT2Attention(nn.Module): key_states = key_states.view(shape_kv).transpose(1, 2) value_states = value_states.view(shape_kv).transpose(1, 2) - if layer_past is not None: - past_key, past_value = layer_past - key_states = torch.cat((past_key, key_states), dim=-2) - value_states = torch.cat((past_value, value_states), dim=-2) + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + if is_cross_attention: + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache + cache_kwargs = {"cache_position": cache_position} + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs + ) - if use_cache is True: - present = (key_states, value_states) - else: - present = None - - is_cross_attention = encoder_hidden_states is not None is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention using_eager = self.config._attn_implementation == "eager" @@ -348,11 +353,7 @@ class GPT2Attention(nn.Module): attn_output = self.c_proj(attn_output) attn_output = self.resid_dropout(attn_output) - outputs = (attn_output, present) - if output_attentions: - outputs += (attn_weights,) - - return outputs # a, present, (attentions) + return attn_output, attn_weights class GPT2MLP(nn.Module): @@ -388,10 +389,12 @@ class GPT2Block(nn.Module): self.mlp = GPT2MLP(inner_dim, config) + @deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True) def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], - layer_past: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, head_mask: Optional[torch.FloatTensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, @@ -401,16 +404,15 @@ class GPT2Block(nn.Module): ) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]: residual = hidden_states hidden_states = self.ln_1(hidden_states) - attn_outputs = self.attn( + attn_output, self_attn_weights = self.attn( hidden_states, - layer_past=layer_past, + past_key_value=past_key_value, + cache_position=cache_position, attention_mask=attention_mask, head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, ) - attn_output = attn_outputs[0] # output_attn: a, present, (attentions) - outputs = attn_outputs[1:] # residual connection hidden_states = attn_output + residual @@ -423,18 +425,17 @@ class GPT2Block(nn.Module): ) residual = hidden_states hidden_states = self.ln_cross_attn(hidden_states) - cross_attn_outputs = self.crossattention( + cross_attn_output, cross_attn_weights = self.crossattention( hidden_states, + past_key_value=past_key_value, attention_mask=attention_mask, head_mask=head_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, output_attentions=output_attentions, ) - attn_output = cross_attn_outputs[0] # residual connection - hidden_states = residual + attn_output - outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights + hidden_states = residual + cross_attn_output residual = hidden_states hidden_states = self.ln_2(hidden_states) @@ -442,12 +443,13 @@ class GPT2Block(nn.Module): # residual connection hidden_states = residual + feed_forward_hidden_states - if use_cache: - outputs = (hidden_states,) + outputs - else: - outputs = (hidden_states,) + outputs[1:] + outputs = (hidden_states,) + if output_attentions: + outputs += (self_attn_weights,) + if encoder_hidden_states is not None: + outputs += (cross_attn_weights,) - return outputs # hidden_states, present, (attentions, cross_attentions) + return outputs # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2 @@ -565,6 +567,8 @@ class GPT2PreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def __init__(self, *inputs, **kwargs): super().__init__(*inputs, **kwargs) @@ -669,10 +673,24 @@ GPT2_INPUTS_DOCSTRING = r""" [`PreTrainedTokenizer.__call__`] for details. [What are input IDs?](../glossary#input-ids) - past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`): - Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see - `past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have - their past given to this model should not be passed as `input_ids` as they have already been computed. + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: @@ -721,6 +739,10 @@ GPT2_INPUTS_DOCSTRING = r""" more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. """ PARALLELIZE_DOCSTRING = r""" This is an experimental feature and is a subject to change at a moment's notice. @@ -868,7 +890,8 @@ class GPT2Model(GPT2PreTrainedModel): def forward( self, input_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor]], Cache]] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -906,51 +929,56 @@ class GPT2Model(GPT2PreTrainedModel): if token_type_ids is not None: token_type_ids = token_type_ids.view(-1, input_shape[-1]) - if past_key_values is None: - past_length = 0 - past_key_values = tuple([None] * len(self.h)) - else: - past_length = past_key_values[0][0].size(-2) - if position_ids is None: - position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device) - position_ids = position_ids.unsqueeze(0) + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder + return_legacy_cache = False + if use_cache: + if past_key_values is None: + return_legacy_cache = True + past_key_values = DynamicCache() + elif not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. " + "You should pass an instance of `Cache` instead, e.g. " + "`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + + if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache): + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) if inputs_embeds is None: inputs_embeds = self.wte(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + position_embeds = self.wpe(position_ids) hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device) # Attention mask. - _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None - attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None - if self._attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif _use_sdpa: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask=attention_mask, - input_shape=(batch_size, input_shape[-1]), - inputs_embeds=inputs_embeds, - past_key_values_length=past_length, - ) - else: - if attention_mask is not None: - # We create a 3D attention mask from a 2D tensor mask. - # Sizes are [batch_size, 1, 1, to_seq_length] - # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] - # this attention mask is more simple than the triangular masking of causal attention - # used in OpenAI GPT, we just need to prepare the broadcast dimension here. - attention_mask = attention_mask[:, None, None, :] - - # Since attention_mask is 1.0 for positions we want to attend and 0.0 for - # masked positions, this operation will create a tensor which is 0.0 for - # positions we want to attend and the dtype's smallest value for masked positions. - # Since we are adding it to the raw scores before the softmax, this is - # effectively the same as removing these entirely. - attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility - attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min + # ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel + if attention_mask is not None and attention_mask.ndim < 4: + attention_mask = attention_mask.view(batch_size, -1) + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + _use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None if self.config.add_cross_attention and encoder_hidden_states is not None: encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) @@ -979,25 +1007,13 @@ class GPT2Model(GPT2PreTrainedModel): output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - presents = () if use_cache else None all_self_attentions = () if output_attentions else None all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None all_hidden_states = () if output_hidden_states else None - for i in range(len(self.h)): - block, layer_past = self.h[i], past_key_values[i] + for i, block in enumerate(self.h): # Model parallel if self.model_parallel: torch.cuda.set_device(hidden_states.device) - # Ensure layer_past is on same device as hidden_states (might not be correct) - if layer_past is not None: - layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past) # Ensure that attention_mask is always on the same device as hidden_states if attention_mask is not None: attention_mask = attention_mask.to(hidden_states.device) @@ -1010,8 +1026,9 @@ class GPT2Model(GPT2PreTrainedModel): outputs = self._gradient_checkpointing_func( block.__call__, hidden_states, - None, - attention_mask, + past_key_values, + cache_position, + causal_mask, head_mask[i], encoder_hidden_states, encoder_attention_mask, @@ -1021,8 +1038,9 @@ class GPT2Model(GPT2PreTrainedModel): else: outputs = block( hidden_states, - layer_past=layer_past, - attention_mask=attention_mask, + past_key_value=past_key_values, + cache_position=cache_position, + attention_mask=causal_mask, head_mask=head_mask[i], encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, @@ -1031,13 +1049,11 @@ class GPT2Model(GPT2PreTrainedModel): ) hidden_states = outputs[0] - if use_cache is True: - presents = presents + (outputs[1],) if output_attentions: - all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + all_self_attentions = all_self_attentions + (outputs[1],) if self.config.add_cross_attention: - all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],) + all_cross_attentions = all_cross_attentions + (outputs[2],) # Model Parallel: If it's the last layer for that device, put things on the next device if self.model_parallel: @@ -1052,21 +1068,149 @@ class GPT2Model(GPT2PreTrainedModel): if output_hidden_states: all_hidden_states = all_hidden_states + (hidden_states,) + past_key_values = past_key_values if use_cache else None + if return_legacy_cache: + past_key_values = ( + past_key_values.self_attention_cache.to_legacy_cache() + if self.config.add_cross_attention + else past_key_values.to_legacy_cache() + ) if not return_dict: return tuple( v - for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions] + for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions] if v is not None ) return BaseModelOutputWithPastAndCrossAttentions( last_hidden_state=hidden_states, - past_key_values=presents, + past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attentions, cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to plcae the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @add_start_docstrings( """ @@ -1137,6 +1281,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -1163,6 +1308,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): input_ids, past_key_values=past_key_values, attention_mask=attention_mask, + cache_position=cache_position, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, @@ -1206,20 +1352,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin): cross_attentions=transformer_outputs.cross_attentions, ) - @staticmethod - def _reorder_cache( - past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor - ) -> Tuple[Tuple[torch.Tensor]]: - """ - This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or - [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct - beam_idx at every generation step. - """ - return tuple( - tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) - for layer_past in past_key_values - ) - @add_start_docstrings( """ @@ -1292,6 +1424,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + cache_position: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -1350,6 +1483,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin): transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + cache_position=cache_position, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index bbf0268c8b..96c757fd8f 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -40,9 +40,9 @@ if is_torch_available(): from transformers import ( AutoModelForCausalLM, AutoTokenizer, + ClvpForCausalLM, DynamicCache, GenerationConfig, - GPT2LMHeadModel, LlamaConfig, SinkCache, StaticCache, @@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase): def test_reorder_cache_retrocompatibility(self): """Tests that Cache.reorder_cache is retrocompatible with the legacy code path""" - legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function + legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function legacy_cache = () new_cache = DynamicCache()