Fix pix2struct (#34374)

* fix

* fix and test use_cache test

* style

* remove atol
This commit is contained in:
Ilyas Moutawwakil
2024-10-28 11:24:56 +01:00
committed by GitHub
parent 1d06379331
commit fddbd3c13c
2 changed files with 46 additions and 25 deletions

View File

@@ -762,11 +762,14 @@ class Pix2StructTextAttention(nn.Module):
return relative_buckets return relative_buckets
# Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -779,6 +782,7 @@ class Pix2StructTextAttention(nn.Module):
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
return values return values
# Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
def forward( def forward(
self, self,
hidden_states, hidden_states,
@@ -796,61 +800,66 @@ class Pix2StructTextAttention(nn.Module):
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
# if key_value_states are provided this layer is used as a cross-attention layer for the decoder # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None is_cross_attention = key_value_states is not None
query_states = self.query(hidden_states).contiguous() query_states = self.query(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx) is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention: if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache # after the first generated id, we can subsequently re-use all key/value_states from cache
past_key_value = past_key_value.cross_attention_cache curr_past_key_value = past_key_value.cross_attention_cache
else: else:
past_key_value = past_key_value.self_attention_cache curr_past_key_value = past_key_value.self_attention_cache
# get key/value states
current_states = key_value_states if is_cross_attention else hidden_states current_states = key_value_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value and is_updated: if is_cross_attention and past_key_value and is_updated:
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value.key_cache[self.layer_idx] key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = past_key_value.value_cache[self.layer_idx] value_states = curr_past_key_value.value_cache[self.layer_idx]
else: else:
key_states = self.key(current_states).contiguous() key_states = self.key(current_states)
value_states = self.value(current_states).contiguous() value_states = self.value(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation # save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update( key_states, value_states = curr_past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position} key_states, value_states, self.layer_idx, {"cache_position": cache_position}
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention: if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul(query_states, key_states.transpose(3, 2)) scores = torch.matmul(query_states, key_states.transpose(3, 2))
if position_bias is None: if position_bias is None:
real_seq_length = cache_position[-1] + 1 if query_length is None else query_length key_length = key_states.shape[-2]
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
)
position_bias = position_bias[:, :, -seq_length:, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -860,10 +869,9 @@ class Pix2StructTextAttention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
# (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
# (batch_size, n_heads, seq_length, key_length) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# Mask heads if we want to # Mask heads if we want to
@@ -871,12 +879,12 @@ class Pix2StructTextAttention(nn.Module):
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = torch.matmul(attn_weights, value_states) attn_output = torch.matmul(attn_weights, value_states)
# (batch_size, seq_length, dim)
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.output(attn_output) attn_output = self.output(attn_output)
outputs = (attn_output,) + (past_key_value,) + (position_bias,) outputs = (attn_output, past_key_value, position_bias)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -969,7 +977,10 @@ class Pix2StructTextBlock(nn.Module):
layer_idx=layer_idx, layer_idx=layer_idx,
) )
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
config,
layer_idx=layer_idx,
)
self.mlp = Pix2StructTextLayerFF(config) self.mlp = Pix2StructTextLayerFF(config)
@@ -1019,7 +1030,6 @@ class Pix2StructTextBlock(nn.Module):
query_length=cache_position[-1] + 1, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, past_key_value = cross_attention_outputs[:2] hidden_states, past_key_value = cross_attention_outputs[:2]

View File

@@ -419,6 +419,7 @@ class Pix2StructModelTester:
@require_torch @require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else () all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {} pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False fx_compatible = False
test_head_masking = False test_head_masking = False
@@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
), ),
) )
def test_generative_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).eval().to(torch_device)
output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)
torch.testing.assert_close(output, output_use_cache)
@unittest.skip(reason="Hidden_states is tested in individual model tests") @unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self): def test_hidden_states_output(self):
pass pass