Fix GPT2 with cross attention (#39754)
* fix * use new mask API * style * fix copies and attention tests * fix head pruning tests
This commit is contained in:
committed by
GitHub
parent
dfd616e658
commit
ccb2e0e03b
@@ -268,52 +268,61 @@ class DecisionTransformerGPT2Attention(nn.Module):
|
||||
**kwargs,
|
||||
) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_layer from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
if is_cross_attention:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
"Please make sure to instantiate class with `DecisionTransformerGPT2Attention(..., is_cross_attention=True)`."
|
||||
)
|
||||
|
||||
query_states = self.q_attn(hidden_states)
|
||||
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
attention_mask = encoder_attention_mask
|
||||
|
||||
# Try to get key/value states from cache if possible
|
||||
if past_key_value is not None and is_updated:
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
else:
|
||||
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
|
||||
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
|
||||
|
||||
query_states = query_states.view(shape_q).transpose(1, 2)
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
if is_cross_attention:
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
|
||||
query_states = query_states.view(shape_q).transpose(1, 2)
|
||||
|
||||
if (past_key_value is not None and not is_cross_attention) or (
|
||||
past_key_value is not None and is_cross_attention and not is_updated
|
||||
):
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||
|
||||
using_eager = self.config._attn_implementation == "eager"
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
|
||||
using_eager = True
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
|
||||
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
|
||||
# not necessarily to eager (if mentioned options are provided).
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if using_eager and self.reorder_and_upcast_attn:
|
||||
|
||||
@@ -27,9 +27,10 @@ from torch import nn
|
||||
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
||||
|
||||
from ...activations import ACT2FN, get_activation
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
|
||||
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
|
||||
from ...generation import GenerationMixin
|
||||
from ...modeling_attn_mask_utils import AttentionMaskConverter, _prepare_4d_attention_mask_for_sdpa
|
||||
from ...masking_utils import create_causal_mask
|
||||
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
|
||||
from ...modeling_layers import GradientCheckpointingLayer
|
||||
from ...modeling_outputs import (
|
||||
BaseModelOutputWithPastAndCrossAttentions,
|
||||
@@ -278,52 +279,61 @@ class GPT2Attention(nn.Module):
|
||||
**kwargs,
|
||||
) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
|
||||
is_cross_attention = encoder_hidden_states is not None
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||
if is_cross_attention:
|
||||
# after the first generated id, we can subsequently re-use all key/value_layer from cache
|
||||
curr_past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value.self_attention_cache
|
||||
else:
|
||||
curr_past_key_value = past_key_value
|
||||
|
||||
if is_cross_attention:
|
||||
if not hasattr(self, "q_attn"):
|
||||
raise ValueError(
|
||||
"If class is used as cross attention, the weights `q_attn` have to be defined. "
|
||||
"Please make sure to instantiate class with `GPT2Attention(..., is_cross_attention=True)`."
|
||||
)
|
||||
|
||||
query_states = self.q_attn(hidden_states)
|
||||
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
attention_mask = encoder_attention_mask
|
||||
|
||||
# Try to get key/value states from cache if possible
|
||||
if past_key_value is not None and is_updated:
|
||||
key_states = curr_past_key_value.layers[self.layer_idx].keys
|
||||
value_states = curr_past_key_value.layers[self.layer_idx].values
|
||||
else:
|
||||
key_states, value_states = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
else:
|
||||
query_states, key_states, value_states = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
|
||||
shape_kv = (*key_states.shape[:-1], -1, self.head_dim)
|
||||
|
||||
query_states = query_states.view(shape_q).transpose(1, 2)
|
||||
key_states = key_states.view(shape_kv).transpose(1, 2)
|
||||
value_states = value_states.view(shape_kv).transpose(1, 2)
|
||||
|
||||
if past_key_value is not None:
|
||||
if isinstance(past_key_value, EncoderDecoderCache):
|
||||
if is_cross_attention:
|
||||
past_key_value = past_key_value.cross_attention_cache
|
||||
else:
|
||||
past_key_value = past_key_value.self_attention_cache
|
||||
cache_kwargs = {"cache_position": cache_position}
|
||||
key_states, value_states = past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, cache_kwargs=cache_kwargs
|
||||
shape_q = (*query_states.shape[:-1], -1, self.head_dim)
|
||||
query_states = query_states.view(shape_q).transpose(1, 2)
|
||||
|
||||
if (past_key_value is not None and not is_cross_attention) or (
|
||||
past_key_value is not None and is_cross_attention and not is_updated
|
||||
):
|
||||
# save all key/value_layer to cache to be re-used for fast auto-regressive generation
|
||||
cache_position = cache_position if not is_cross_attention else None
|
||||
key_states, value_states = curr_past_key_value.update(
|
||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||
)
|
||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||
if is_cross_attention:
|
||||
past_key_value.is_updated[self.layer_idx] = True
|
||||
|
||||
is_causal = attention_mask is None and query_states.shape[-2] > 1 and not is_cross_attention
|
||||
|
||||
using_eager = self.config._attn_implementation == "eager"
|
||||
attention_interface: Callable = eager_attention_forward
|
||||
if self.config._attn_implementation != "eager":
|
||||
if self.config._attn_implementation == "sdpa" and (output_attentions or head_mask is not None):
|
||||
using_eager = True
|
||||
logger.warning_once(
|
||||
"`torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to "
|
||||
'eager attention. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
|
||||
)
|
||||
else:
|
||||
# Attention functions are consistent with previous equivalent attention classes, however they do not support some options
|
||||
# (e.g. layer scaling, head mask) that eager supports. These implementations are thus equivalent to previous code, but
|
||||
# not necessarily to eager (if mentioned options are provided).
|
||||
attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
|
||||
|
||||
if using_eager and self.reorder_and_upcast_attn:
|
||||
@@ -861,8 +871,14 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
# ._update_causal_mask() and ._prepare_4d_causal_attention_mask_with_cache_position() copied from LlamaModel
|
||||
if attention_mask is not None and attention_mask.ndim < 4:
|
||||
attention_mask = attention_mask.view(batch_size, -1)
|
||||
causal_mask = self._update_causal_mask(
|
||||
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
|
||||
|
||||
causal_mask = create_causal_mask(
|
||||
config=self.config,
|
||||
input_embeds=inputs_embeds,
|
||||
attention_mask=attention_mask,
|
||||
cache_position=cache_position,
|
||||
past_key_values=past_key_values,
|
||||
position_ids=position_ids,
|
||||
)
|
||||
|
||||
# If a 2D or 3D attention mask is provided for the cross-attention
|
||||
@@ -903,9 +919,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
# Model parallel
|
||||
if self.model_parallel:
|
||||
torch.cuda.set_device(hidden_states.device)
|
||||
# Ensure that attention_mask is always on the same device as hidden_states
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask.to(hidden_states.device)
|
||||
if isinstance(head_mask, torch.Tensor):
|
||||
head_mask = head_mask.to(hidden_states.device)
|
||||
if output_hidden_states:
|
||||
@@ -966,123 +979,6 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
cross_attentions=all_cross_attentions,
|
||||
)
|
||||
|
||||
def _update_causal_mask(
|
||||
self,
|
||||
attention_mask: torch.Tensor,
|
||||
input_tensor: torch.Tensor,
|
||||
cache_position: torch.Tensor,
|
||||
past_key_values: Cache,
|
||||
output_attentions: bool,
|
||||
):
|
||||
if self.config._attn_implementation == "flash_attention_2":
|
||||
if attention_mask is not None and 0.0 in attention_mask:
|
||||
return attention_mask
|
||||
return None
|
||||
|
||||
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
|
||||
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
|
||||
# to infer the attention mask.
|
||||
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
||||
using_static_cache = isinstance(past_key_values, StaticCache)
|
||||
|
||||
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
|
||||
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
|
||||
if AttentionMaskConverter._ignore_causal_mask_sdpa(
|
||||
attention_mask,
|
||||
inputs_embeds=input_tensor,
|
||||
past_key_values_length=past_seen_tokens,
|
||||
is_training=self.training,
|
||||
):
|
||||
return None
|
||||
|
||||
dtype = input_tensor.dtype
|
||||
sequence_length = input_tensor.shape[1]
|
||||
if using_static_cache:
|
||||
target_length = past_key_values.get_max_cache_shape()
|
||||
else:
|
||||
target_length = (
|
||||
attention_mask.shape[-1]
|
||||
if isinstance(attention_mask, torch.Tensor)
|
||||
else past_seen_tokens + sequence_length + 1
|
||||
)
|
||||
|
||||
# In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
|
||||
causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask,
|
||||
sequence_length=sequence_length,
|
||||
target_length=target_length,
|
||||
dtype=dtype,
|
||||
cache_position=cache_position,
|
||||
batch_size=input_tensor.shape[0],
|
||||
)
|
||||
|
||||
if (
|
||||
self.config._attn_implementation == "sdpa"
|
||||
and attention_mask is not None
|
||||
and attention_mask.device.type == "cuda"
|
||||
and not output_attentions
|
||||
):
|
||||
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
|
||||
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
|
||||
# Details: https://github.com/pytorch/pytorch/issues/110213
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
|
||||
|
||||
return causal_mask
|
||||
|
||||
@staticmethod
|
||||
def _prepare_4d_causal_attention_mask_with_cache_position(
|
||||
attention_mask: torch.Tensor,
|
||||
sequence_length: int,
|
||||
target_length: int,
|
||||
dtype: torch.dtype,
|
||||
cache_position: torch.Tensor,
|
||||
batch_size: int,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
|
||||
`(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
|
||||
|
||||
Args:
|
||||
attention_mask (`torch.Tensor`):
|
||||
A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
|
||||
`(batch_size, 1, query_length, key_value_length)`.
|
||||
sequence_length (`int`):
|
||||
The sequence length being processed.
|
||||
target_length (`int`):
|
||||
The target length: when generating with static cache, the mask should be as long as the static cache,
|
||||
to account for the 0 padding, the part of the cache that is not filled yet.
|
||||
dtype (`torch.dtype`):
|
||||
The dtype to use for the 4D attention mask.
|
||||
cache_position (`torch.Tensor`):
|
||||
Indices depicting the position of the input sequence tokens in the sequence.
|
||||
batch_size (`torch.Tensor`):
|
||||
Batch size.
|
||||
"""
|
||||
if attention_mask is not None and attention_mask.dim() == 4:
|
||||
# In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
|
||||
causal_mask = attention_mask
|
||||
else:
|
||||
min_dtype = torch.finfo(dtype).min
|
||||
causal_mask = torch.full(
|
||||
(sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
|
||||
)
|
||||
if sequence_length != 1:
|
||||
causal_mask = torch.triu(causal_mask, diagonal=1)
|
||||
causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
|
||||
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
|
||||
if attention_mask is not None:
|
||||
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
|
||||
mask_length = attention_mask.shape[-1]
|
||||
padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
|
||||
padding_mask = padding_mask == 0
|
||||
causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
|
||||
padding_mask, min_dtype
|
||||
)
|
||||
|
||||
return causal_mask
|
||||
|
||||
|
||||
@auto_docstring(
|
||||
custom_intro="""
|
||||
|
||||
@@ -449,6 +449,7 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
cache_position: Optional[torch.LongTensor] = None,
|
||||
**kwargs,
|
||||
) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
|
||||
r"""
|
||||
@@ -561,6 +562,7 @@ class VisionEncoderDecoderModel(PreTrainedModel, GenerationMixin):
|
||||
use_cache=use_cache,
|
||||
past_key_values=past_key_values,
|
||||
return_dict=return_dict,
|
||||
cache_position=cache_position,
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
|
||||
@@ -1775,6 +1775,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.set_attn_implementation("eager")
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
@@ -1808,6 +1809,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.set_attn_implementation("eager")
|
||||
heads_to_prune = {
|
||||
0: list(range(1, self.model_tester.num_attention_heads)),
|
||||
-1: [0],
|
||||
@@ -1816,7 +1818,7 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
@@ -1852,6 +1854,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.set_attn_implementation("eager")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
@@ -1884,6 +1887,7 @@ class ModelTesterMixin:
|
||||
model = model_class(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
model.set_attn_implementation("eager")
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
|
||||
@@ -1894,7 +1898,7 @@ class ModelTesterMixin:
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_dir_name:
|
||||
model.save_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name)
|
||||
model = model_class.from_pretrained(temp_dir_name, attn_implementation="eager")
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user