Fix pix2struct (#34374)
* fix * fix and test use_cache test * style * remove atol
This commit is contained in:
committed by
GitHub
parent
1d06379331
commit
fddbd3c13c
@@ -762,11 +762,14 @@ class Pix2StructTextAttention(nn.Module):
|
|||||||
return relative_buckets
|
return relative_buckets
|
||||||
|
|
||||||
# Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
|
# Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias
|
||||||
def compute_bias(self, query_length, key_length, device=None):
|
def compute_bias(self, query_length, key_length, device=None, cache_position=None):
|
||||||
"""Compute binned relative position bias"""
|
"""Compute binned relative position bias"""
|
||||||
if device is None:
|
if device is None:
|
||||||
device = self.relative_attention_bias.weight.device
|
device = self.relative_attention_bias.weight.device
|
||||||
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
if cache_position is None:
|
||||||
|
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
|
||||||
|
else:
|
||||||
|
context_position = cache_position[:, None].to(device)
|
||||||
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
|
||||||
relative_position = memory_position - context_position # shape (query_length, key_length)
|
relative_position = memory_position - context_position # shape (query_length, key_length)
|
||||||
relative_position_bucket = self._relative_position_bucket(
|
relative_position_bucket = self._relative_position_bucket(
|
||||||
@@ -779,6 +782,7 @@ class Pix2StructTextAttention(nn.Module):
|
|||||||
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
# Adapted from transformers.models.t5.modeling_t5.T5Attention.forward
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
@@ -796,61 +800,66 @@ class Pix2StructTextAttention(nn.Module):
|
|||||||
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
|
||||||
"""
|
"""
|
||||||
# Input is (batch_size, seq_length, dim)
|
# Input is (batch_size, seq_length, dim)
|
||||||
# Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
|
# Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder)
|
||||||
batch_size, seq_length = hidden_states.shape[:2]
|
batch_size, seq_length = hidden_states.shape[:2]
|
||||||
|
|
||||||
# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
|
# if key_value_states are provided this layer is used as a cross-attention layer for the decoder
|
||||||
is_cross_attention = key_value_states is not None
|
is_cross_attention = key_value_states is not None
|
||||||
|
|
||||||
query_states = self.query(hidden_states).contiguous()
|
query_states = self.query(hidden_states)
|
||||||
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
is_updated = past_key_value.is_updated.get(self.layer_idx)
|
||||||
if is_cross_attention:
|
if is_cross_attention:
|
||||||
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
# after the first generated id, we can subsequently re-use all key/value_states from cache
|
||||||
past_key_value = past_key_value.cross_attention_cache
|
curr_past_key_value = past_key_value.cross_attention_cache
|
||||||
else:
|
else:
|
||||||
past_key_value = past_key_value.self_attention_cache
|
curr_past_key_value = past_key_value.self_attention_cache
|
||||||
|
|
||||||
# get key/value states
|
|
||||||
current_states = key_value_states if is_cross_attention else hidden_states
|
current_states = key_value_states if is_cross_attention else hidden_states
|
||||||
if is_cross_attention and past_key_value and is_updated:
|
if is_cross_attention and past_key_value and is_updated:
|
||||||
# reuse k,v, cross_attentions
|
# reuse k,v, cross_attentions
|
||||||
key_states = past_key_value.key_cache[self.layer_idx]
|
key_states = curr_past_key_value.key_cache[self.layer_idx]
|
||||||
value_states = past_key_value.value_cache[self.layer_idx]
|
value_states = curr_past_key_value.value_cache[self.layer_idx]
|
||||||
else:
|
else:
|
||||||
key_states = self.key(current_states).contiguous()
|
key_states = self.key(current_states)
|
||||||
value_states = self.value(current_states).contiguous()
|
value_states = self.value(current_states)
|
||||||
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
||||||
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
# save all key/value_states to cache to be re-used for fast auto-regressive generation
|
||||||
cache_position = cache_position if not is_cross_attention else None
|
cache_position = cache_position if not is_cross_attention else None
|
||||||
key_states, value_states = past_key_value.update(
|
key_states, value_states = curr_past_key_value.update(
|
||||||
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
|
||||||
)
|
)
|
||||||
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
|
||||||
if is_cross_attention:
|
if is_cross_attention:
|
||||||
past_key_value.is_updated[self.layer_idx] = True
|
past_key_value.is_updated[self.layer_idx] = True
|
||||||
|
|
||||||
# compute scores
|
# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
|
||||||
scores = torch.matmul(query_states, key_states.transpose(3, 2))
|
scores = torch.matmul(query_states, key_states.transpose(3, 2))
|
||||||
|
|
||||||
if position_bias is None:
|
if position_bias is None:
|
||||||
real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
|
key_length = key_states.shape[-2]
|
||||||
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
|
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
|
||||||
|
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
|
||||||
if not self.has_relative_attention_bias:
|
if not self.has_relative_attention_bias:
|
||||||
position_bias = torch.zeros(
|
position_bias = torch.zeros(
|
||||||
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
|
(1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
|
||||||
)
|
)
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
position_bias.requires_grad = True
|
position_bias.requires_grad = True
|
||||||
else:
|
else:
|
||||||
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
|
position_bias = self.compute_bias(
|
||||||
|
real_seq_length, key_length, device=scores.device, cache_position=cache_position
|
||||||
|
)
|
||||||
|
position_bias = position_bias[:, :, -seq_length:, :]
|
||||||
|
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
|
causal_mask = mask[:, :, :, : key_states.shape[-2]]
|
||||||
|
position_bias = position_bias + causal_mask
|
||||||
|
|
||||||
if self.pruned_heads:
|
if self.pruned_heads:
|
||||||
mask = torch.ones(position_bias.shape[1])
|
mask = torch.ones(position_bias.shape[1])
|
||||||
@@ -860,10 +869,9 @@ class Pix2StructTextAttention(nn.Module):
|
|||||||
position_bias_masked = position_bias
|
position_bias_masked = position_bias
|
||||||
|
|
||||||
scores += position_bias_masked
|
scores += position_bias_masked
|
||||||
# (batch_size, n_heads, seq_length, key_length)
|
|
||||||
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
|
|
||||||
|
|
||||||
# (batch_size, n_heads, seq_length, key_length)
|
# (batch_size, n_heads, seq_length, key_length)
|
||||||
|
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
|
||||||
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
|
||||||
|
|
||||||
# Mask heads if we want to
|
# Mask heads if we want to
|
||||||
@@ -871,12 +879,12 @@ class Pix2StructTextAttention(nn.Module):
|
|||||||
attn_weights = attn_weights * layer_head_mask
|
attn_weights = attn_weights * layer_head_mask
|
||||||
|
|
||||||
attn_output = torch.matmul(attn_weights, value_states)
|
attn_output = torch.matmul(attn_weights, value_states)
|
||||||
# (batch_size, seq_length, dim)
|
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
|
|
||||||
|
|
||||||
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
|
||||||
attn_output = self.output(attn_output)
|
attn_output = self.output(attn_output)
|
||||||
|
|
||||||
outputs = (attn_output,) + (past_key_value,) + (position_bias,)
|
outputs = (attn_output, past_key_value, position_bias)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
outputs = outputs + (attn_weights,)
|
outputs = outputs + (attn_weights,)
|
||||||
@@ -969,7 +977,10 @@ class Pix2StructTextBlock(nn.Module):
|
|||||||
layer_idx=layer_idx,
|
layer_idx=layer_idx,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config)
|
self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
|
||||||
|
config,
|
||||||
|
layer_idx=layer_idx,
|
||||||
|
)
|
||||||
|
|
||||||
self.mlp = Pix2StructTextLayerFF(config)
|
self.mlp = Pix2StructTextLayerFF(config)
|
||||||
|
|
||||||
@@ -1019,7 +1030,6 @@ class Pix2StructTextBlock(nn.Module):
|
|||||||
query_length=cache_position[-1] + 1,
|
query_length=cache_position[-1] + 1,
|
||||||
use_cache=use_cache,
|
use_cache=use_cache,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
cache_position=cache_position,
|
|
||||||
)
|
)
|
||||||
hidden_states, past_key_value = cross_attention_outputs[:2]
|
hidden_states, past_key_value = cross_attention_outputs[:2]
|
||||||
|
|
||||||
|
|||||||
@@ -419,6 +419,7 @@ class Pix2StructModelTester:
|
|||||||
@require_torch
|
@require_torch
|
||||||
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||||
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
|
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
|
||||||
|
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
|
||||||
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
|
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
|
||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
@@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_generative_model(self):
|
||||||
|
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
for model_class in self.all_generative_model_classes:
|
||||||
|
model = model_class(config).eval().to(torch_device)
|
||||||
|
|
||||||
|
output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
|
||||||
|
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)
|
||||||
|
|
||||||
|
torch.testing.assert_close(output, output_use_cache)
|
||||||
|
|
||||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||||
def test_hidden_states_output(self):
|
def test_hidden_states_output(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user