GPT2Model StaticCache support (#35761)
* initial GPT2 changes * causal_mask support * return_legacy_cache * cleanup * fix1 * outputs shape fixes * gpt2 return fix * pkv, attn fixes * fix dual_head * is_causal arg fix * decision transformer updated * style fix * batch_size from inputs_embeds * DecisionTransformerModel fixes * cross-attn support + cache warning * x-attn @decision * EDCache proper init * simplified logic in `if use_cache:` for GPT2Model * @deprecate_kwarg for DecisionTr attn fwd * @deprecate_kwarg in gpt2 * deprecation version updated to 4.51 * kwargs in gradient_checkpointing_fn * rename next_cache to past_key_values * attention_mask prep * +cache_position in GPT2DoubleHeadsModel * undo kwargs in gradient checkpointing * moved up `if self.gradient_checkpointing` * consistency in decision_transformer * pastkv, cache_pos in grad_checkpt args * rm _reorder_cache * output_attentions streamlined * decision_transformer consistency * return_legacy_cache improved * ClvpForCausalLM used for legacy cache test now * is_causal fixed * attn_output cleanup * consistency @ decision_transformer * Updated deprecation notice version to 4.52 * upd deprecation * consistent legacy cache code in decision transformers\ * next_cache -> past_kv in decision_tr * cache support flags in decision_transf * rm legacy cache warning * consistency in cache init for decision transf * no Static Cache for Decision Transformer --------- Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -1589,7 +1589,6 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2LMHeadModel._reorder_cache
|
|
||||||
def _reorder_cache(
|
def _reorder_cache(
|
||||||
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
||||||
) -> Tuple[Tuple[torch.Tensor]]:
|
) -> Tuple[Tuple[torch.Tensor]]:
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch.utils.checkpoint
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from ...activations import ACT2FN
|
from ...activations import ACT2FN
|
||||||
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||||
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions
|
||||||
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
||||||
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
from ...pytorch_utils import Conv1D, find_pruneable_heads_and_indices, prune_conv1d_layer
|
||||||
@@ -34,6 +35,7 @@ from ...utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
from .configuration_decision_transformer import DecisionTransformerConfig
|
from .configuration_decision_transformer import DecisionTransformerConfig
|
||||||
|
|
||||||
|
|
||||||
@@ -125,7 +127,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
|
|||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
attn_weights = attn_weights + attention_mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
@@ -257,19 +260,21 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
|
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
if is_cross_attention:
|
||||||
if not hasattr(self, "q_attn"):
|
if not hasattr(self, "q_attn"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||||
@@ -289,17 +294,17 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||||
|
|
||||||
if layer_past is not None:
|
if past_key_value is not None:
|
||||||
past_key, past_value = layer_past
|
if isinstance(past_key_value, EncoderDecoderCache):
|
||||||
key_states = torch.cat((past_key, key_states), dim=-2)
|
if is_cross_attention:
|
||||||
value_states = torch.cat((past_value, value_states), dim=-2)
|
past_key_value = past_key_value.cross_attention_cache
|
||||||
|
else:
|
||||||
|
past_key_value = past_key_value.self_attention_cache
|
||||||
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
present = (key_states, value_states)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
is_cross_attention = encoder_hidden_states is not None
|
|
||||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||||
|
|
||||||
using_eager = self.config._attn_implementation == "eager"
|
using_eager = self.config._attn_implementation == "eager"
|
||||||
@@ -338,11 +343,7 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
|||||||
attn_output = self.c_proj(attn_output)
|
attn_output = self.c_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
return attn_output, attn_weights
|
||||||
if output_attentions:
|
|
||||||
outputs += (attn_weights,)
|
|
||||||
|
|
||||||
return outputs # a, present, (attentions)
|
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
|
# Copied from transformers.models.gpt2.modeling_gpt2.GPT2MLP with GPT2->DecisionTransformerGPT2
|
||||||
@@ -383,10 +384,12 @@ class DecisionTransformerGPT2Block(nn.Module):
|
|||||||
|
|
||||||
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
|
self.mlp = DecisionTransformerGPT2MLP(inner_dim, config)
|
||||||
|
|
||||||
|
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
@@ -396,16 +399,15 @@ class DecisionTransformerGPT2Block(nn.Module):
|
|||||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_1(hidden_states)
|
hidden_states = self.ln_1(hidden_states)
|
||||||
attn_outputs = self.attn(
|
attn_output, self_attn_weights = self.attn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
past_key_value=past_key_value,
|
||||||
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
|
||||||
outputs = attn_outputs[1:]
|
|
||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = attn_output + residual
|
hidden_states = attn_output + residual
|
||||||
|
|
||||||
@@ -418,18 +420,17 @@ class DecisionTransformerGPT2Block(nn.Module):
|
|||||||
)
|
)
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_cross_attn(hidden_states)
|
hidden_states = self.ln_cross_attn(hidden_states)
|
||||||
cross_attn_outputs = self.crossattention(
|
cross_attn_output, cross_attn_weights = self.crossattention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
past_key_value=past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = cross_attn_outputs[0]
|
|
||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = residual + attn_output
|
hidden_states = residual + cross_attn_output
|
||||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_2(hidden_states)
|
hidden_states = self.ln_2(hidden_states)
|
||||||
@@ -437,12 +438,13 @@ class DecisionTransformerGPT2Block(nn.Module):
|
|||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = residual + feed_forward_hidden_states
|
hidden_states = residual + feed_forward_hidden_states
|
||||||
|
|
||||||
if use_cache:
|
outputs = (hidden_states,)
|
||||||
outputs = (hidden_states,) + outputs
|
if output_attentions:
|
||||||
else:
|
outputs += (self_attn_weights,)
|
||||||
outputs = (hidden_states,) + outputs[1:]
|
if encoder_hidden_states is not None:
|
||||||
|
outputs += (cross_attn_weights,)
|
||||||
|
|
||||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
||||||
@@ -456,6 +458,8 @@ class DecisionTransformerGPT2PreTrainedModel(PreTrainedModel):
|
|||||||
base_model_prefix = "transformer"
|
base_model_prefix = "transformer"
|
||||||
is_parallelizable = True
|
is_parallelizable = True
|
||||||
supports_gradient_checkpointing = True
|
supports_gradient_checkpointing = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = False
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@@ -521,6 +525,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -558,14 +563,31 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
if past_key_values is None:
|
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder and similar addition in GPT2Model
|
||||||
past_length = 0
|
return_legacy_cache = False
|
||||||
past_key_values = tuple([None] * len(self.h))
|
if use_cache:
|
||||||
else:
|
if past_key_values is None:
|
||||||
past_length = past_key_values[0][0].size(-2)
|
return_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
elif not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
|
||||||
|
"You should pass an instance of `Cache` instead, e.g. "
|
||||||
|
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
||||||
|
)
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
|
||||||
|
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
position_ids = cache_position.unsqueeze(0)
|
||||||
position_ids = position_ids.unsqueeze(0)
|
|
||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
@@ -624,17 +646,13 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
use_cache = False
|
use_cache = False
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
for i, block in enumerate(self.h):
|
||||||
# Model parallel
|
# Model parallel
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
torch.cuda.set_device(hidden_states.device)
|
torch.cuda.set_device(hidden_states.device)
|
||||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
|
||||||
if layer_past is not None:
|
|
||||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
|
||||||
# Ensure that attention_mask is always on the same device as hidden_states
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
@@ -648,6 +666,7 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
attention_mask,
|
attention_mask,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
@@ -658,7 +677,8 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
past_key_value=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
@@ -668,13 +688,11 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
|
||||||
presents = presents + (outputs[1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
all_cross_attentions = all_cross_attentions + (outputs[2],)
|
||||||
|
|
||||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
@@ -689,16 +707,23 @@ class DecisionTransformerGPT2Model(DecisionTransformerGPT2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_values = past_key_values if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
past_key_values = (
|
||||||
|
past_key_values.self_attention_cache.to_legacy_cache()
|
||||||
|
if self.config.add_cross_attention
|
||||||
|
else past_key_values.to_legacy_cache()
|
||||||
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=past_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
|
|||||||
@@ -27,8 +27,9 @@ from torch import nn
|
|||||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||||
|
|
||||||
from ...activations import ACT2FN, get_activation
|
from ...activations import ACT2FN, get_activation
|
||||||
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||||
from ...generation import GenerationMixin
|
from ...generation import GenerationMixin
|
||||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
|
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
|
||||||
from ...modeling_outputs import (
|
from ...modeling_outputs import (
|
||||||
BaseModelOutputWithPastAndCrossAttentions,
|
BaseModelOutputWithPastAndCrossAttentions,
|
||||||
CausalLMOutputWithCrossAttentions,
|
CausalLMOutputWithCrossAttentions,
|
||||||
@@ -46,6 +47,7 @@ from ...utils import (
|
|||||||
logging,
|
logging,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
)
|
)
|
||||||
|
from ...utils.deprecation import deprecate_kwarg
|
||||||
from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
from ...utils.model_parallel_utils import assert_device_map, get_device_map
|
||||||
from .configuration_gpt2 import GPT2Config
|
from .configuration_gpt2 import GPT2Config
|
||||||
|
|
||||||
@@ -136,7 +138,8 @@ def eager_attention_forward(module, query, key, value, attention_mask, head_mask
|
|||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# Apply the attention mask
|
# Apply the attention mask
|
||||||
attn_weights = attn_weights + attention_mask
|
causal_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
|
attn_weights = attn_weights + causal_mask
|
||||||
|
|
||||||
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1)
|
||||||
|
|
||||||
@@ -267,19 +270,21 @@ class GPT2Attention(nn.Module):
|
|||||||
|
|
||||||
return attn_output, attn_weights
|
return attn_output, attn_weights
|
||||||
|
|
||||||
|
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
use_cache: Optional[bool] = False,
|
|
||||||
output_attentions: Optional[bool] = False,
|
output_attentions: Optional[bool] = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]], ...]:
|
||||||
if encoder_hidden_states is not None:
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
if is_cross_attention:
|
||||||
if not hasattr(self, "q_attn"):
|
if not hasattr(self, "q_attn"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||||
@@ -299,17 +304,17 @@ class GPT2Attention(nn.Module):
|
|||||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||||
|
|
||||||
if layer_past is not None:
|
if past_key_value is not None:
|
||||||
past_key, past_value = layer_past
|
if isinstance(past_key_value, EncoderDecoderCache):
|
||||||
key_states = torch.cat((past_key, key_states), dim=-2)
|
if is_cross_attention:
|
||||||
value_states = torch.cat((past_value, value_states), dim=-2)
|
past_key_value = past_key_value.cross_attention_cache
|
||||||
|
else:
|
||||||
|
past_key_value = past_key_value.self_attention_cache
|
||||||
|
cache_kwargs = {"cache_position": cache_position}
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
present = (key_states, value_states)
|
|
||||||
else:
|
|
||||||
present = None
|
|
||||||
|
|
||||||
is_cross_attention = encoder_hidden_states is not None
|
|
||||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||||
|
|
||||||
using_eager = self.config._attn_implementation == "eager"
|
using_eager = self.config._attn_implementation == "eager"
|
||||||
@@ -348,11 +353,7 @@ class GPT2Attention(nn.Module):
|
|||||||
attn_output = self.c_proj(attn_output)
|
attn_output = self.c_proj(attn_output)
|
||||||
attn_output = self.resid_dropout(attn_output)
|
attn_output = self.resid_dropout(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output, present)
|
return attn_output, attn_weights
|
||||||
if output_attentions:
|
|
||||||
outputs += (attn_weights,)
|
|
||||||
|
|
||||||
return outputs # a, present, (attentions)
|
|
||||||
|
|
||||||
|
|
||||||
class GPT2MLP(nn.Module):
|
class GPT2MLP(nn.Module):
|
||||||
@@ -388,10 +389,12 @@ class GPT2Block(nn.Module):
|
|||||||
|
|
||||||
self.mlp = GPT2MLP(inner_dim, config)
|
self.mlp = GPT2MLP(inner_dim, config)
|
||||||
|
|
||||||
|
@deprecate_kwarg("layer_past", new_name="past_key_value", version="4.53.0", raise_if_both_names=True)
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
hidden_states: Optional[Tuple[torch.FloatTensor]],
|
||||||
layer_past: Optional[Tuple[torch.Tensor]] = None,
|
past_key_value: Optional[Cache] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
head_mask: Optional[torch.FloatTensor] = None,
|
head_mask: Optional[torch.FloatTensor] = None,
|
||||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||||
@@ -401,16 +404,15 @@ class GPT2Block(nn.Module):
|
|||||||
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
) -> Union[Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]]:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_1(hidden_states)
|
hidden_states = self.ln_1(hidden_states)
|
||||||
attn_outputs = self.attn(
|
attn_output, self_attn_weights = self.attn(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
past_key_value=past_key_value,
|
||||||
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
|
||||||
outputs = attn_outputs[1:]
|
|
||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = attn_output + residual
|
hidden_states = attn_output + residual
|
||||||
|
|
||||||
@@ -423,18 +425,17 @@ class GPT2Block(nn.Module):
|
|||||||
)
|
)
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_cross_attn(hidden_states)
|
hidden_states = self.ln_cross_attn(hidden_states)
|
||||||
cross_attn_outputs = self.crossattention(
|
cross_attn_output, cross_attn_weights = self.crossattention(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
past_key_value=past_key_value,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
)
|
)
|
||||||
attn_output = cross_attn_outputs[0]
|
|
||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = residual + attn_output
|
hidden_states = residual + cross_attn_output
|
||||||
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.ln_2(hidden_states)
|
hidden_states = self.ln_2(hidden_states)
|
||||||
@@ -442,12 +443,13 @@ class GPT2Block(nn.Module):
|
|||||||
# residual connection
|
# residual connection
|
||||||
hidden_states = residual + feed_forward_hidden_states
|
hidden_states = residual + feed_forward_hidden_states
|
||||||
|
|
||||||
if use_cache:
|
outputs = (hidden_states,)
|
||||||
outputs = (hidden_states,) + outputs
|
if output_attentions:
|
||||||
else:
|
outputs += (self_attn_weights,)
|
||||||
outputs = (hidden_states,) + outputs[1:]
|
if encoder_hidden_states is not None:
|
||||||
|
outputs += (cross_attn_weights,)
|
||||||
|
|
||||||
return outputs # hidden_states, present, (attentions, cross_attentions)
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
|
# Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->GPT2
|
||||||
@@ -565,6 +567,8 @@ class GPT2PreTrainedModel(PreTrainedModel):
|
|||||||
_skip_keys_device_placement = "past_key_values"
|
_skip_keys_device_placement = "past_key_values"
|
||||||
_supports_flash_attn_2 = True
|
_supports_flash_attn_2 = True
|
||||||
_supports_sdpa = True
|
_supports_sdpa = True
|
||||||
|
_supports_cache_class = True
|
||||||
|
_supports_static_cache = True
|
||||||
|
|
||||||
def __init__(self, *inputs, **kwargs):
|
def __init__(self, *inputs, **kwargs):
|
||||||
super().__init__(*inputs, **kwargs)
|
super().__init__(*inputs, **kwargs)
|
||||||
@@ -669,10 +673,24 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
[`PreTrainedTokenizer.__call__`] for details.
|
[`PreTrainedTokenizer.__call__`] for details.
|
||||||
|
|
||||||
[What are input IDs?](../glossary#input-ids)
|
[What are input IDs?](../glossary#input-ids)
|
||||||
past_key_values (`Tuple[Tuple[torch.Tensor]]` of length `config.n_layers`):
|
past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
|
||||||
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
|
||||||
`past_key_values` output below). Can be used to speed up sequential decoding. The `input_ids` which have
|
blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
|
||||||
their past given to this model should not be passed as `input_ids` as they have already been computed.
|
returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
|
||||||
|
|
||||||
|
Two formats are allowed:
|
||||||
|
- a [`~cache_utils.Cache`] instance, see our
|
||||||
|
[kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache);
|
||||||
|
- Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
|
||||||
|
shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
|
||||||
|
cache format.
|
||||||
|
|
||||||
|
The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
|
||||||
|
legacy cache format will be returned.
|
||||||
|
|
||||||
|
If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
|
||||||
|
have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
|
||||||
|
of shape `(batch_size, sequence_length)`.
|
||||||
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||||
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
|
||||||
|
|
||||||
@@ -721,6 +739,10 @@ GPT2_INPUTS_DOCSTRING = r"""
|
|||||||
more detail.
|
more detail.
|
||||||
return_dict (`bool`, *optional*):
|
return_dict (`bool`, *optional*):
|
||||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||||
|
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
|
||||||
|
this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
|
||||||
|
the complete sequence length.
|
||||||
"""
|
"""
|
||||||
PARALLELIZE_DOCSTRING = r"""
|
PARALLELIZE_DOCSTRING = r"""
|
||||||
This is an experimental feature and is a subject to change at a moment's notice.
|
This is an experimental feature and is a subject to change at a moment's notice.
|
||||||
@@ -868,7 +890,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Union[Tuple[Tuple[torch.Tensor]], Cache]] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -906,51 +929,56 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
||||||
|
|
||||||
if past_key_values is None:
|
if self.gradient_checkpointing and self.training:
|
||||||
past_length = 0
|
if use_cache:
|
||||||
past_key_values = tuple([None] * len(self.h))
|
logger.warning_once(
|
||||||
else:
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
past_length = past_key_values[0][0].size(-2)
|
)
|
||||||
if position_ids is None:
|
use_cache = False
|
||||||
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
|
||||||
position_ids = position_ids.unsqueeze(0)
|
# based on pattern from src/transformers/models/whisper/modeling_whisper.py::WhisperDecoder
|
||||||
|
return_legacy_cache = False
|
||||||
|
if use_cache:
|
||||||
|
if past_key_values is None:
|
||||||
|
return_legacy_cache = True
|
||||||
|
past_key_values = DynamicCache()
|
||||||
|
elif not isinstance(past_key_values, Cache):
|
||||||
|
return_legacy_cache = True
|
||||||
|
logger.warning_once(
|
||||||
|
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.53.0. "
|
||||||
|
"You should pass an instance of `Cache` instead, e.g. "
|
||||||
|
"`past_key_values=DynamicCache.from_legacy_cache(past_key_values)`."
|
||||||
|
)
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
|
||||||
|
if self.config.add_cross_attention and not isinstance(past_key_values, EncoderDecoderCache):
|
||||||
|
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = self.wte(input_ids)
|
inputs_embeds = self.wte(input_ids)
|
||||||
|
|
||||||
|
if cache_position is None:
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
cache_position = torch.arange(
|
||||||
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
|
||||||
|
)
|
||||||
|
if position_ids is None:
|
||||||
|
position_ids = cache_position.unsqueeze(0)
|
||||||
|
|
||||||
position_embeds = self.wpe(position_ids)
|
position_embeds = self.wpe(position_ids)
|
||||||
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
|
hidden_states = inputs_embeds + position_embeds.to(inputs_embeds.device)
|
||||||
|
|
||||||
# Attention mask.
|
# Attention mask.
|
||||||
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
# ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
|
||||||
attention_mask = attention_mask.view(batch_size, -1) if attention_mask is not None else None
|
if attention_mask is not None and attention_mask.ndim < 4:
|
||||||
if self._attn_implementation == "flash_attention_2":
|
attention_mask = attention_mask.view(batch_size, -1)
|
||||||
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
causal_mask = self._update_causal_mask(
|
||||||
elif _use_sdpa:
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||||
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
)
|
||||||
attention_mask=attention_mask,
|
|
||||||
input_shape=(batch_size, input_shape[-1]),
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
past_key_values_length=past_length,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
if attention_mask is not None:
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
|
||||||
attention_mask = attention_mask[:, None, None, :]
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and the dtype's smallest value for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
|
||||||
attention_mask = (1.0 - attention_mask) * torch.finfo(self.dtype).min
|
|
||||||
|
|
||||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||||
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
||||||
|
_use_sdpa = self._attn_implementation == "sdpa" and output_attentions is False and head_mask is None
|
||||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||||
@@ -979,25 +1007,13 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
|
|
||||||
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
output_shape = (-1,) + input_shape[1:] + (hidden_states.size(-1),)
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
|
||||||
if use_cache:
|
|
||||||
logger.warning_once(
|
|
||||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
|
||||||
)
|
|
||||||
use_cache = False
|
|
||||||
|
|
||||||
presents = () if use_cache else None
|
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
for i in range(len(self.h)):
|
for i, block in enumerate(self.h):
|
||||||
block, layer_past = self.h[i], past_key_values[i]
|
|
||||||
# Model parallel
|
# Model parallel
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
torch.cuda.set_device(hidden_states.device)
|
torch.cuda.set_device(hidden_states.device)
|
||||||
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
|
||||||
if layer_past is not None:
|
|
||||||
layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
|
|
||||||
# Ensure that attention_mask is always on the same device as hidden_states
|
# Ensure that attention_mask is always on the same device as hidden_states
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attention_mask = attention_mask.to(hidden_states.device)
|
attention_mask = attention_mask.to(hidden_states.device)
|
||||||
@@ -1010,8 +1026,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
outputs = self._gradient_checkpointing_func(
|
outputs = self._gradient_checkpointing_func(
|
||||||
block.__call__,
|
block.__call__,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
None,
|
past_key_values,
|
||||||
attention_mask,
|
cache_position,
|
||||||
|
causal_mask,
|
||||||
head_mask[i],
|
head_mask[i],
|
||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
@@ -1021,8 +1038,9 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
else:
|
else:
|
||||||
outputs = block(
|
outputs = block(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
layer_past=layer_past,
|
past_key_value=past_key_values,
|
||||||
attention_mask=attention_mask,
|
cache_position=cache_position,
|
||||||
|
attention_mask=causal_mask,
|
||||||
head_mask=head_mask[i],
|
head_mask=head_mask[i],
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
encoder_attention_mask=encoder_attention_mask,
|
encoder_attention_mask=encoder_attention_mask,
|
||||||
@@ -1031,13 +1049,11 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if use_cache is True:
|
|
||||||
presents = presents + (outputs[1],)
|
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
|
all_self_attentions = all_self_attentions + (outputs[1],)
|
||||||
if self.config.add_cross_attention:
|
if self.config.add_cross_attention:
|
||||||
all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)
|
all_cross_attentions = all_cross_attentions + (outputs[2],)
|
||||||
|
|
||||||
# Model Parallel: If it's the last layer for that device, put things on the next device
|
# Model Parallel: If it's the last layer for that device, put things on the next device
|
||||||
if self.model_parallel:
|
if self.model_parallel:
|
||||||
@@ -1052,21 +1068,149 @@ class GPT2Model(GPT2PreTrainedModel):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
|
past_key_values = past_key_values if use_cache else None
|
||||||
|
if return_legacy_cache:
|
||||||
|
past_key_values = (
|
||||||
|
past_key_values.self_attention_cache.to_legacy_cache()
|
||||||
|
if self.config.add_cross_attention
|
||||||
|
else past_key_values.to_legacy_cache()
|
||||||
|
)
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
return tuple(
|
return tuple(
|
||||||
v
|
v
|
||||||
for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
|
for v in [hidden_states, past_key_values, all_hidden_states, all_self_attentions, all_cross_attentions]
|
||||||
if v is not None
|
if v is not None
|
||||||
)
|
)
|
||||||
|
|
||||||
return BaseModelOutputWithPastAndCrossAttentions(
|
return BaseModelOutputWithPastAndCrossAttentions(
|
||||||
last_hidden_state=hidden_states,
|
last_hidden_state=hidden_states,
|
||||||
past_key_values=presents,
|
past_key_values=past_key_values,
|
||||||
hidden_states=all_hidden_states,
|
hidden_states=all_hidden_states,
|
||||||
attentions=all_self_attentions,
|
attentions=all_self_attentions,
|
||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _update_causal_mask(
|
||||||
|
self,
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
input_tensor: torch.Tensor,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
past_key_values: Cache,
|
||||||
|
output_attentions: bool,
|
||||||
|
):
|
||||||
|
if self.config._attn_implementation == "flash_attention_2":
|
||||||
|
if attention_mask is not None and 0.0 in attention_mask:
|
||||||
|
return attention_mask
|
||||||
|
return None
|
||||||
|
|
||||||
|
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||||
|
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||||
|
# to infer the attention mask.
|
||||||
|
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||||
|
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||||
|
|
||||||
|
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||||
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||||
|
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||||
|
attention_mask,
|
||||||
|
inputs_embeds=input_tensor,
|
||||||
|
past_key_values_length=past_seen_tokens,
|
||||||
|
is_training=self.training,
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
dtype, device = input_tensor.dtype, input_tensor.device
|
||||||
|
sequence_length = input_tensor.shape[1]
|
||||||
|
if using_static_cache:
|
||||||
|
target_length = past_key_values.get_max_cache_shape()
|
||||||
|
else:
|
||||||
|
target_length = (
|
||||||
|
attention_mask.shape[-1]
|
||||||
|
if isinstance(attention_mask, torch.Tensor)
|
||||||
|
else past_seen_tokens + sequence_length + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||||
|
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask,
|
||||||
|
sequence_length=sequence_length,
|
||||||
|
target_length=target_length,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
cache_position=cache_position,
|
||||||
|
batch_size=input_tensor.shape[0],
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.config._attn_implementation == "sdpa"
|
||||||
|
and attention_mask is not None
|
||||||
|
and attention_mask.device.type == "cuda"
|
||||||
|
and not output_attentions
|
||||||
|
):
|
||||||
|
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||||
|
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||||
|
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||||
|
attention_mask: torch.Tensor,
|
||||||
|
sequence_length: int,
|
||||||
|
target_length: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
device: torch.device,
|
||||||
|
cache_position: torch.Tensor,
|
||||||
|
batch_size: int,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||||
|
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
attention_mask (`torch.Tensor`):
|
||||||
|
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||||
|
`(batch_size, 1, query_length, key_value_length)`.
|
||||||
|
sequence_length (`int`):
|
||||||
|
The sequence length being processed.
|
||||||
|
target_length (`int`):
|
||||||
|
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||||
|
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||||
|
dtype (`torch.dtype`):
|
||||||
|
The dtype to use for the 4D attention mask.
|
||||||
|
device (`torch.device`):
|
||||||
|
The device to plcae the 4D attention mask on.
|
||||||
|
cache_position (`torch.Tensor`):
|
||||||
|
Indices depicting the position of the input sequence tokens in the sequence.
|
||||||
|
batch_size (`torch.Tensor`):
|
||||||
|
Batch size.
|
||||||
|
"""
|
||||||
|
if attention_mask is not None and attention_mask.dim() == 4:
|
||||||
|
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||||
|
causal_mask = attention_mask
|
||||||
|
else:
|
||||||
|
min_dtype = torch.finfo(dtype).min
|
||||||
|
causal_mask = torch.full(
|
||||||
|
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
if sequence_length != 1:
|
||||||
|
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||||
|
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
|
||||||
|
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||||
|
if attention_mask is not None:
|
||||||
|
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||||
|
mask_length = attention_mask.shape[-1]
|
||||||
|
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||||
|
padding_mask = padding_mask == 0
|
||||||
|
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||||
|
padding_mask, min_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
return causal_mask
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@@ -1137,6 +1281,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -1163,6 +1308,7 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
|||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
cache_position=cache_position,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
head_mask=head_mask,
|
head_mask=head_mask,
|
||||||
@@ -1206,20 +1352,6 @@ class GPT2LMHeadModel(GPT2PreTrainedModel, GenerationMixin):
|
|||||||
cross_attentions=transformer_outputs.cross_attentions,
|
cross_attentions=transformer_outputs.cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _reorder_cache(
|
|
||||||
past_key_values: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
|
|
||||||
) -> Tuple[Tuple[torch.Tensor]]:
|
|
||||||
"""
|
|
||||||
This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
|
|
||||||
[`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
|
|
||||||
beam_idx at every generation step.
|
|
||||||
"""
|
|
||||||
return tuple(
|
|
||||||
tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
|
|
||||||
for layer_past in past_key_values
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"""
|
"""
|
||||||
@@ -1292,6 +1424,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
|
|||||||
self,
|
self,
|
||||||
input_ids: Optional[torch.LongTensor] = None,
|
input_ids: Optional[torch.LongTensor] = None,
|
||||||
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
|
||||||
|
cache_position: Optional[torch.LongTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
token_type_ids: Optional[torch.LongTensor] = None,
|
token_type_ids: Optional[torch.LongTensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -1350,6 +1483,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel, GenerationMixin):
|
|||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
past_key_values=past_key_values,
|
past_key_values=past_key_values,
|
||||||
|
cache_position=cache_position,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
|
|||||||
@@ -40,9 +40,9 @@ if is_torch_available():
|
|||||||
from transformers import (
|
from transformers import (
|
||||||
AutoModelForCausalLM,
|
AutoModelForCausalLM,
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
|
ClvpForCausalLM,
|
||||||
DynamicCache,
|
DynamicCache,
|
||||||
GenerationConfig,
|
GenerationConfig,
|
||||||
GPT2LMHeadModel,
|
|
||||||
LlamaConfig,
|
LlamaConfig,
|
||||||
SinkCache,
|
SinkCache,
|
||||||
StaticCache,
|
StaticCache,
|
||||||
@@ -103,7 +103,7 @@ class CacheTest(unittest.TestCase):
|
|||||||
|
|
||||||
def test_reorder_cache_retrocompatibility(self):
|
def test_reorder_cache_retrocompatibility(self):
|
||||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||||
legacy_reorder_fn = GPT2LMHeadModel._reorder_cache # An example of a legacy `_reorder_cache` function
|
legacy_reorder_fn = ClvpForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||||
|
|
||||||
legacy_cache = ()
|
legacy_cache = ()
|
||||||
new_cache = DynamicCache()
|
new_cache = DynamicCache()
|
||||||
|
|||||||
Reference in New Issue
Block a user