GPT2Model StaticCache support (#35761)

* initial GPT2 changes

* causal_mask support

* return_legacy_cache

* cleanup

* fix1

* outputs shape fixes

* gpt2 return fix

* pkv, attn fixes

* fix dual_head

* is_causal arg fix

* decision transformer updated

* style fix

* batch_size from inputs_embeds

* DecisionTransformerModel fixes

* cross-attn support + cache warning

* x-attn @decision

* EDCache proper init

* simplified logic in `if use_cache:` for GPT2Model

* @deprecate_kwarg for DecisionTr attn fwd

* @deprecate_kwarg in gpt2

* deprecation version updated to 4.51

* kwargs in gradient_checkpointing_fn

* rename next_cache to past_key_values

* attention_mask prep

* +cache_position in GPT2DoubleHeadsModel

* undo kwargs in gradient checkpointing

* moved up `if self.gradient_checkpointing`

* consistency in decision_transformer

* pastkv, cache_pos in grad_checkpt args

* rm _reorder_cache

* output_attentions streamlined

* decision_transformer consistency

* return_legacy_cache improved

* ClvpForCausalLM used for legacy cache test now

* is_causal fixed

* attn_output cleanup

* consistency @ decision_transformer

* Updated deprecation notice version to 4.52

* upd deprecation

* consistent legacy cache code in decision transformers\

* next_cache -> past_kv in decision_tr

* cache support flags in decision_transf

* rm legacy cache warning

* consistency in cache init for decision transf

* no Static Cache for Decision Transformer

---------

Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Poedator
2025-04-24 15:46:35 +03:00
committed by GitHub
parent 9f927c8250
commit 7c62e69326
4 changed files with 324 additions and 166 deletions

View File

@@ -1589,7 +1589,6 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
)
@staticmethod
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:

View File

@@ -24,6 +24,7 @@ import torch.utils.checkpoint
from torch import nn
from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
@@ -34,6 +35,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_decision_transformer import DecisionTransformerConfig
@@ -125,7 +127,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -257,19 +260,21 @@ class DecisionTransformerGPT2Attention(nn.Module):
return attn_output, attn_weights
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
@@ -289,17 +294,17 @@ class DecisionTransformerGPT2Attention(nn.Module):
key_states = key_states.view(shape_kv).transpose(1, 2)
value_states = value_states.view(shape_kv).transpose(1, 2)
if layer_past is not None:
past_key, past_value = layer_past
key_states = torch.cat((past_key, key_states), dim=-2)
value_states = torch.cat((past_value, value_states), dim=-2)
if use_cache is True:
present = (key_states, value_states)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
if is_cross_attention:
past_key_value = past_key_value.cross_attention_cache
else:
present = None
past_key_value = past_key_value.self_attention_cache
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
)
is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
using_eager = self.config._attn_implementation == "eager"
@@ -338,11 +343,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
return attn_output, attn_weights
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
@@ -383,10 +384,12 @@ class DecisionTransformerGPT2Block(nn.Module):
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -396,16 +399,15 @@ class DecisionTransformerGPT2Block(nn.Module):
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
attn_output, self_attn_weights = self.attn(
hidden_states,
layer_past=layer_past,
past_key_value=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
@@ -418,18 +420,17 @@ class DecisionTransformerGPT2Block(nn.Module):
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
cross_attn_output, cross_attn_weights = self.crossattention(
hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
hidden_states = residual + cross_attn_output
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
@@ -437,12 +438,13 @@ class DecisionTransformerGPT2Block(nn.Module):
# residual connection
hidden_states = residual + feed_forward_hidden_states
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if encoder_hidden_states is not None:
outputs += (cross_attn_weights,)
return outputs # hidden_states, present, (attentions, cross_attentions)
return outputs
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
@@ -456,6 +458,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
is_parallelizable = True
supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = False
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -521,6 +525,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -558,14 +563,31 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder and similar addition in GPT2Model
return_legacy_cache = False
if use_cache:
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
return_legacy_cache = True
past_key_values = DynamicCache()
elif not isinstance(past_key_values, Cache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
"You should pass an instance of `Cache` instead, e.g. "
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
position_ids = cache_position.unsqueeze(0)
# Attention mask.
if attention_mask is not None:
@@ -624,17 +646,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
for i, block in enumerate(self.h):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
@@ -648,6 +666,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
block.__call__,
hidden_states,
None,
None,
attention_mask,
head_mask[i],
encoder_hidden_states,
@@ -658,7 +677,8 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
else:
outputs = block(
hidden_states,
layer_past=layer_past,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
@@ -668,13 +688,11 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
all_cross_attentions = all_cross_attentions + (outputs[2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
@@ -689,16 +707,23 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_values = past_key_values if use_cache else None
if return_legacy_cache:
past_key_values = (
past_key_values.self_attention_cache.to_legacy_cache()
if self.config.add_cross_attention
else past_key_values.to_legacy_cache()
)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,

View File

@@ -27,8 +27,9 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN, get_activation
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
@@ -46,6 +47,7 @@ from ...utils import (
logging,
replace_return_docstrings,
)
from ...utils.deprecation import deprecate_kwarg
from ...utils.model_parallel_utils import assert_device_map, get_device_map
from .configuration_gpt2 import GPT2Config
@@ -136,7 +138,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
attn_weights = attn_weights + causal_mask
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
@@ -267,19 +270,21 @@ class GPT2Attention(nn.Module):
return attn_output, attn_weights
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
**kwargs,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
if encoder_hidden_states is not None:
is_cross_attention = encoder_hidden_states is not None
if is_cross_attention:
if not hasattr(self, "q_attn"):
raise ValueError(
"If class is used as cross attention, the weights `q_attn` have to be defined. "
@@ -299,17 +304,17 @@ class GPT2Attention(nn.Module):
key_states = key_states.view(shape_kv).transpose(1, 2)
value_states = value_states.view(shape_kv).transpose(1, 2)
if layer_past is not None:
past_key, past_value = layer_past
key_states = torch.cat((past_key, key_states), dim=-2)
value_states = torch.cat((past_value, value_states), dim=-2)
if use_cache is True:
present = (key_states, value_states)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
if is_cross_attention:
past_key_value = past_key_value.cross_attention_cache
else:
present = None
past_key_value = past_key_value.self_attention_cache
cache_kwargs = {"cache_position": cache_position}
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
)
is_cross_attention = encoder_hidden_states is not None
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
using_eager = self.config._attn_implementation == "eager"
@@ -348,11 +353,7 @@ class GPT2Attention(nn.Module):
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
outputs = (attn_output, present)
if output_attentions:
outputs += (attn_weights,)
return outputs # a, present, (attentions)
return attn_output, attn_weights
class GPT2MLP(nn.Module):
@@ -388,10 +389,12 @@ class GPT2Block(nn.Module):
self.mlp = GPT2MLP(inner_dim, config)
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
layer_past: Optional[Tuple[torch.Tensor]] = None,
past_key_value: Optional[Cache] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
@@ -401,16 +404,15 @@ class GPT2Block(nn.Module):
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_outputs = self.attn(
attn_output, self_attn_weights = self.attn(
hidden_states,
layer_past=layer_past,
past_key_value=past_key_value,
cache_position=cache_position,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + residual
@@ -423,18 +425,17 @@ class GPT2Block(nn.Module):
)
residual = hidden_states
hidden_states = self.ln_cross_attn(hidden_states)
cross_attn_outputs = self.crossattention(
cross_attn_output, cross_attn_weights = self.crossattention(
hidden_states,
past_key_value=past_key_value,
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = residual + attn_output
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
hidden_states = residual + cross_attn_output
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
@@ -442,12 +443,13 @@ class GPT2Block(nn.Module):
# residual connection
hidden_states = residual + feed_forward_hidden_states
if use_cache:
outputs = (hidden_states,) + outputs
else:
outputs = (hidden_states,) + outputs[1:]
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if encoder_hidden_states is not None:
outputs += (cross_attn_weights,)
return outputs # hidden_states, present, (attentions, cross_attentions)
return outputs
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
@@ -565,6 +567,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
_supports_cache_class = True
_supports_static_cache = True
def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
@@ -669,10 +673,24 @@ GPT2_INPUTS_DOCSTRING = r"""
[`PreTrainedTokenizer.__call__`] for details.
[What are input IDs?](../glossary#input-ids)
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
their past given to this model should not be passed as `input_ids` as they have already been computed.
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
Two formats are allowed:
- a [`~cache_utils.Cache`] instance, see our
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
cache format.
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
legacy cache format will be returned.
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
of shape `(batch_size, sequence_length)`.
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
@@ -721,6 +739,10 @@ GPT2_INPUTS_DOCSTRING = r"""
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
the complete sequence length.
"""
PARALLELIZE_DOCSTRING = r"""
This is an experimental feature and is a subject to change at a moment's notice.
@@ -868,7 +890,8 @@ class GPT2Model(GPT2PreTrainedModel):
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor]], Cache]] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -906,51 +929,56 @@ class GPT2Model(GPT2PreTrainedModel):
if token_type_ids is not None:
token_type_ids = token_type_ids.view(-1, input_shape[-1])
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
return_legacy_cache = False
if use_cache:
if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * len(self.h))
else:
past_length = past_key_values[0][0].size(-2)
if position_ids is None:
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0)
return_legacy_cache = True
past_key_values = DynamicCache()
elif not isinstance(past_key_values, Cache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
"You should pass an instance of `Cache` instead, e.g. "
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)
if cache_position is None:
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
cache_position = torch.arange(
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
# Attention mask.
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
if self._attn_implementation == "flash_attention_2":
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif _use_sdpa:
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask=attention_mask,
input_shape=(batch_size, input_shape[-1]),
inputs_embeds=inputs_embeds,
past_key_values_length=past_length,
# ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
if attention_mask is not None and attention_mask.ndim < 4:
attention_mask = attention_mask.view(batch_size, -1)
causal_mask = self._update_causal_mask(
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
)
else:
if attention_mask is not None:
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and the dtype's smallest value for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
@@ -979,25 +1007,13 @@ class GPT2Model(GPT2PreTrainedModel):
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
presents = () if use_cache else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
all_hidden_states = () if output_hidden_states else None
for i in range(len(self.h)):
block, layer_past = self.h[i], past_key_values[i]
for i, block in enumerate(self.h):
# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
# Ensure layer_past is on same device as hidden_states (might not be correct)
if layer_past is not None:
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
# Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None:
attention_mask = attention_mask.to(hidden_states.device)
@@ -1010,8 +1026,9 @@ class GPT2Model(GPT2PreTrainedModel):
outputs = self._gradient_checkpointing_func(
block.__call__,
hidden_states,
None,
attention_mask,
past_key_values,
cache_position,
causal_mask,
head_mask[i],
encoder_hidden_states,
encoder_attention_mask,
@@ -1021,8 +1038,9 @@ class GPT2Model(GPT2PreTrainedModel):
else:
outputs = block(
hidden_states,
layer_past=layer_past,
attention_mask=attention_mask,
past_key_value=past_key_values,
cache_position=cache_position,
attention_mask=causal_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
@@ -1031,13 +1049,11 @@ class GPT2Model(GPT2PreTrainedModel):
)
hidden_states = outputs[0]
if use_cache is True:
presents = presents + (outputs[1],)
if output_attentions:
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
all_self_attentions = all_self_attentions + (outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
all_cross_attentions = all_cross_attentions + (outputs[2],)
# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
@@ -1052,21 +1068,149 @@ class GPT2Model(GPT2PreTrainedModel):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
past_key_values = past_key_values if use_cache else None
if return_legacy_cache:
past_key_values = (
past_key_values.self_attention_cache.to_legacy_cache()
if self.config.add_cross_attention
else past_key_values.to_legacy_cache()
)
if not return_dict:
return tuple(
v
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=presents,
past_key_values=past_key_values,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
def _update_causal_mask(
self,
attention_mask: torch.Tensor,
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
output_attentions: bool,
):
if self.config._attn_implementation == "flash_attention_2":
if attention_mask is not None and 0.0 in attention_mask:
return attention_mask
return None
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
using_static_cache = isinstance(past_key_values, StaticCache)
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
past_key_values_length=past_seen_tokens,
is_training=self.training,
):
return None
dtype, device = input_tensor.dtype, input_tensor.device
sequence_length = input_tensor.shape[1]
if using_static_cache:
target_length = past_key_values.get_max_cache_shape()
else:
target_length = (
attention_mask.shape[-1]
if isinstance(attention_mask, torch.Tensor)
else past_seen_tokens + sequence_length + 1
)
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
attention_mask,
sequence_length=sequence_length,
target_length=target_length,
dtype=dtype,
device=device,
cache_position=cache_position,
batch_size=input_tensor.shape[0],
)
if (
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type == "cuda"
and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
min_dtype = torch.finfo(dtype).min
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
@staticmethod
def _prepare_4d_causal_attention_mask_with_cache_position(
attention_mask: torch.Tensor,
sequence_length: int,
target_length: int,
dtype: torch.dtype,
device: torch.device,
cache_position: torch.Tensor,
batch_size: int,
**kwargs,
):
"""
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
Args:
attention_mask (`torch.Tensor`):
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
`(batch_size, 1, query_length, key_value_length)`.
sequence_length (`int`):
The sequence length being processed.
target_length (`int`):
The target length: when generating with static cache, the mask should be as long as the static cache,
to account for the 0 padding, the part of the cache that is not filled yet.
dtype (`torch.dtype`):
The dtype to use for the 4D attention mask.
device (`torch.device`):
The device to plcae the 4D attention mask on.
cache_position (`torch.Tensor`):
Indices depicting the position of the input sequence tokens in the sequence.
batch_size (`torch.Tensor`):
Batch size.
"""
if attention_mask is not None and attention_mask.dim() == 4:
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
causal_mask = attention_mask
else:
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full(
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
)
if sequence_length != 1:
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
padding_mask = padding_mask == 0
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
padding_mask, min_dtype
)
return causal_mask
@add_start_docstrings(
"""
@@ -1137,6 +1281,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -1163,6 +1308,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
input_ids,
past_key_values=past_key_values,
attention_mask=attention_mask,
cache_position=cache_position,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
@@ -1206,20 +1352,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
cross_attentions=transformer_outputs.cross_attentions,
)
@staticmethod
def _reorder_cache(
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
beam_idx at every generation step.
"""
return tuple(
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
for layer_past in past_key_values
)
@add_start_docstrings(
"""
@@ -1292,6 +1424,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
self,
input_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -1350,6 +1483,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
transformer_outputs = self.transformer(
input_ids,
past_key_values=past_key_values,
cache_position=cache_position,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,

View File

@@ -40,9 +40,9 @@ if is_torch_available():
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
ClvpForCausalLM,
DynamicCache,
GenerationConfig,
GPT2LMHeadModel,
LlamaConfig,
SinkCache,
StaticCache,
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
def test_reorder_cache_retrocompatibility(self):
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
legacy_cache = ()
new_cache = DynamicCache()