T5 compile compatibility (#34089)

* this worked in normal generation, needs more tests

* fix almost all tests in t5

* nit

* longt5, umt5, mt5

* style

* udop, pix2struct

* more models

* fix some tests

* fix onnx tests

* tracing tests fixed

* compile enabled and tested for t5 models

* fix small bug in slow tests

* [run-slow] t5

* uncomment

* style

* update with new generation refactoring

* nit

* fix copies

* this is the fix, had to change t5 to fix copies

* update

* [run-slow] t5

* [run-slow] t5

* update

* add test for encoder only T5

* clean up after rebase

* fix pop2piano

* add comment

* style

* fix copies after rebase

* fix copies, missed this one
This commit is contained in:
Raushan Turganbay
2024-10-22 08:23:53 +02:00
committed by GitHub
parent 5077bc034f
commit 73d65e637b
22 changed files with 2744 additions and 1179 deletions

View File

@@ -1475,11 +1475,7 @@ class EncoderDecoderCache(Cache):
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed.""" """Returns the sequence length of the cached states. A layer index can be optionally passed."""
# check if empty list because in case of static cache it will be a tensor and we can't check `if not torch.Tensor` # check if empty list because in case of static cache it will be a tensor and we can't check `if not torch.Tensor`
if self.self_attention_cache.key_cache == []: return self.self_attention_cache.get_seq_length(layer_idx)
return 0
if len(self.self_attention_cache.key_cache) > 1 and self.self_attention_cache.key_cache[layer_idx] == []:
return 0
return (self.self_attention_cache.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
def reset(self): def reset(self):
if hasattr(self.self_attention_cache, "reset"): if hasattr(self.self_attention_cache, "reset"):

View File

@@ -1535,8 +1535,12 @@ class GenerationMixin:
def _get_initial_cache_position(self, input_ids, model_kwargs): def _get_initial_cache_position(self, input_ids, model_kwargs):
"""Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length""" """Calculates `cache_position` for the pre-fill stage based on `input_ids` and optionally past length"""
# `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange` # `torch.compile`-friendly `torch.arange` from a shape -- the lines below are equivalent to `torch.arange`
if "inputs_embeds" in model_kwargs: if "inputs_embeds" in model_kwargs and not self.config.is_encoder_decoder:
cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1 cache_position = torch.ones_like(model_kwargs["inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
elif "decoder_inputs_embeds" in model_kwargs and self.config.is_encoder_decoder:
cache_position = (
torch.ones_like(model_kwargs["decoder_inputs_embeds"][0, :, 0], dtype=torch.int64).cumsum(0) - 1
)
else: else:
cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1 cache_position = torch.ones_like(input_ids[0, :], dtype=torch.int64).cumsum(0) - 1
@@ -1633,7 +1637,7 @@ class GenerationMixin:
cache_kwargs = { cache_kwargs = {
"config": self.config.get_text_config(), "config": self.config.get_text_config(),
"max_batch_size": batch_size, "batch_size": batch_size,
"max_cache_len": max_cache_len, "max_cache_len": max_cache_len,
"device": device, "device": device,
"dtype": cache_dtype, "dtype": cache_dtype,

View File

@@ -79,7 +79,12 @@ class LongT5Config(PretrainedConfig):
model_type = "longt5" model_type = "longt5"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
}
def __init__( def __init__(
self, self,

View File

@@ -24,7 +24,9 @@ from torch import nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
@@ -39,6 +41,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -317,7 +320,12 @@ class LongT5LayerFF(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5 # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
class LongT5Attention(nn.Module): class LongT5Attention(nn.Module):
def __init__(self, config: LongT5Config, has_relative_attention_bias=False): def __init__(
self,
config: LongT5Config,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -328,6 +336,13 @@ class LongT5Attention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -404,11 +419,14 @@ class LongT5Attention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -432,94 +450,72 @@ class LongT5Attention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -529,22 +525,22 @@ class LongT5Attention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -1008,9 +1004,11 @@ class LongT5TransientGlobalAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5 # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
class LongT5LayerSelfAttention(nn.Module): class LongT5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = LongT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) self.SelfAttention = LongT5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -1023,6 +1021,7 @@ class LongT5LayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -1033,6 +1032,7 @@ class LongT5LayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -1042,7 +1042,7 @@ class LongT5LayerSelfAttention(nn.Module):
class LongT5LayerLocalSelfAttention(nn.Module): class LongT5LayerLocalSelfAttention(nn.Module):
"""Local self attention used in encoder""" """Local self attention used in encoder"""
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -1073,7 +1073,7 @@ class LongT5LayerLocalSelfAttention(nn.Module):
class LongT5LayerTransientGlobalSelfAttention(nn.Module): class LongT5LayerTransientGlobalSelfAttention(nn.Module):
"""Transient-Global self attention used in encoder""" """Transient-Global self attention used in encoder"""
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention( self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
config, has_relative_attention_bias=has_relative_attention_bias config, has_relative_attention_bias=has_relative_attention_bias
@@ -1105,9 +1105,9 @@ class LongT5LayerTransientGlobalSelfAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5 # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
class LongT5LayerCrossAttention(nn.Module): class LongT5LayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False) self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -1122,6 +1122,7 @@ class LongT5LayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -1134,6 +1135,7 @@ class LongT5LayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -1141,7 +1143,7 @@ class LongT5LayerCrossAttention(nn.Module):
class LongT5Block(nn.Module): class LongT5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
if config.is_decoder: if config.is_decoder:
@@ -1156,9 +1158,11 @@ class LongT5Block(nn.Module):
f"but got {config.encoder_attention_type}." f"but got {config.encoder_attention_type}."
) )
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(attention_layer(config, has_relative_attention_bias=has_relative_attention_bias)) self.layer.append(
attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
)
if self.is_decoder: if self.is_decoder:
self.layer.append(LongT5LayerCrossAttention(config)) self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(LongT5LayerFF(config)) self.layer.append(LongT5LayerFF(config))
@@ -1176,34 +1180,19 @@ class LongT5Block(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
@@ -1213,35 +1202,25 @@ class LongT5Block(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/ # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -1256,7 +1235,7 @@ class LongT5Block(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
@@ -1273,6 +1252,8 @@ class LongT5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer" base_model_prefix = "transformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_no_split_modules = ["LongT5Block"] _no_split_modules = ["LongT5Block"]
_supports_cache_class = True
_supports_static_cache = False # TODO: @raushan more involved due to local/global attn
@property @property
# Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
@@ -1376,7 +1357,10 @@ class LongT5Stack(LongT5PreTrainedModel):
self.block_len = self.local_radius + 1 self.block_len = self.local_radius + 1
self.block = nn.ModuleList( self.block = nn.ModuleList(
[LongT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] [
LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
for i in range(config.num_layers)
]
) )
self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -1408,6 +1392,7 @@ class LongT5Stack(LongT5PreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1430,36 +1415,65 @@ class LongT5Stack(LongT5PreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings" assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.embed_tokens(input_ids)
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# initialize past_key_values
return_legacy_cache = False
return_self_attention_cache = False
if self.is_decoder and (use_cache or past_key_values is not None):
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past # required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length mask_seq_length = past_key_values_length + seq_length
if use_cache is True:
assert self.is_decoder, f"`use_cache` can only be set to `True` if {self} is used as a decoder"
if attention_mask is None:
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
past_key_values = [None] * len(self.block)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
# ourselves in which case we just need to make it broadcastable to all heads.
# We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
if self.is_decoder: if self.is_decoder:
extended_attention_mask = self.get_extended_attention_mask( causal_mask = self._update_causal_mask(
attention_mask, input_shape, inputs_embeds.device attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
) )
# We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
elif self.config.encoder_attention_type == "local": elif self.config.encoder_attention_type == "local":
extended_attention_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device) causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
else: # we need to use both local attention mask and standard extended mask for transient-global attention else: # we need to use both local attention mask and standard extended mask for transient-global attention
extended_attention_mask = attention_mask causal_mask = attention_mask
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1472,17 +1486,9 @@ class LongT5Stack(LongT5PreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
@@ -1491,7 +1497,7 @@ class LongT5Stack(LongT5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -1502,7 +1508,7 @@ class LongT5Stack(LongT5PreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -1512,20 +1518,24 @@ class LongT5Stack(LongT5PreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
return_dict,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -1533,7 +1543,7 @@ class LongT5Stack(LongT5PreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -1541,9 +1551,6 @@ class LongT5Stack(LongT5PreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -1557,12 +1564,18 @@ class LongT5Stack(LongT5PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1571,12 +1584,135 @@ class LongT5Stack(LongT5PreTrainedModel):
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Build the decoder self-attention mask appropriate for the configured attention implementation.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller; `None` is tolerated by every branch below.
        input_tensor (`torch.Tensor`):
            Tensor whose dtype, device, `shape[0]` (batch size) and `shape[1]` (sequence length) are used to
            size the generated mask.
        cache_position (`torch.Tensor`):
            Forwarded unchanged to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values (`Cache`):
            Cache object, or `None`. Only `get_seq_length()` and, for a `StaticCache`,
            `get_max_cache_shape()` are called on it.
        output_attentions (`bool`):
            When truthy, both SDPA-specific shortcuts are skipped.

    Returns:
        `None` when no explicit mask is needed, the incoming `attention_mask` unchanged (flash-attention path
        with at least one zero entry), or the mask produced by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
    """
    attn_impl = self.config._attn_implementation

    if attn_impl == "flash_attention_2":
        # Only hand the mask through when it actually masks something (contains a zero).
        has_masked_entry = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_masked_entry else None

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    is_static_cache = isinstance(past_key_values, StaticCache)
    is_sdpa = attn_impl == "sdpa"

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    if is_sdpa and not (is_static_cache or output_attentions):
        can_skip_mask = AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=self.training,
        )
        if can_skip_mask:
            return None

    dtype = input_tensor.dtype
    device = input_tensor.device
    sequence_length = input_tensor.shape[1]

    # Width of the key/value axis of the mask.
    if is_static_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = seen_tokens + sequence_length + 1

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmasking = (
        is_sdpa
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    )
    if needs_unmasking:
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)

    return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the incoming mask itself when it is already 4D; otherwise a new
        `(batch_size, 1, sequence_length, target_length)` tensor of `dtype` holding `0` at positions that may be
        attended to and `torch.finfo(dtype).min` at positions that may not.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the large negative value only on key positions strictly after each query's cache position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # The 2D mask may live on a different device than the freshly built causal mask; move it
        # explicitly so the addition below cannot fail with a device mismatch.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
        # A sum of exactly 0 means "causally visible" (0) AND "padding" (0) -> must be masked out.
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
LONGT5_START_DOCSTRING = r""" LONGT5_START_DOCSTRING = r"""
@@ -1693,6 +1829,9 @@ LONGT5_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
LONGT5_ENCODER_INPUTS_DOCSTRING = r""" LONGT5_ENCODER_INPUTS_DOCSTRING = r"""
@@ -1817,6 +1956,7 @@ class LongT5Model(LongT5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
@@ -1883,6 +2023,7 @@ class LongT5Model(LongT5PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1975,6 +2116,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -2050,6 +2192,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -72,7 +72,12 @@ class MT5Config(PretrainedConfig):
model_type = "mt5" model_type = "mt5"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
}
def __init__( def __init__(
self, self,

View File

@@ -25,7 +25,9 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
@@ -43,6 +45,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -214,7 +217,12 @@ class MT5LayerFF(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5 # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5
class MT5Attention(nn.Module): class MT5Attention(nn.Module):
def __init__(self, config: MT5Config, has_relative_attention_bias=False): def __init__(
self,
config: MT5Config,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -225,6 +233,13 @@ class MT5Attention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -301,11 +316,14 @@ class MT5Attention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -329,94 +347,72 @@ class MT5Attention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -426,22 +422,22 @@ class MT5Attention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -450,9 +446,11 @@ class MT5Attention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5 # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5
class MT5LayerSelfAttention(nn.Module): class MT5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = MT5Attention(config, has_relative_attention_bias=has_relative_attention_bias) self.SelfAttention = MT5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -465,6 +463,7 @@ class MT5LayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -475,6 +474,7 @@ class MT5LayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -483,9 +483,9 @@ class MT5LayerSelfAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5 # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5
class MT5LayerCrossAttention(nn.Module): class MT5LayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False) self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -500,6 +500,7 @@ class MT5LayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -512,6 +513,7 @@ class MT5LayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -520,13 +522,15 @@ class MT5LayerCrossAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5 # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5
class MT5Block(nn.Module): class MT5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) self.layer.append(
MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
)
if self.is_decoder: if self.is_decoder:
self.layer.append(MT5LayerCrossAttention(config)) self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(MT5LayerFF(config)) self.layer.append(MT5LayerFF(config))
@@ -544,34 +548,19 @@ class MT5Block(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -585,25 +574,18 @@ class MT5Block(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
@@ -614,10 +596,6 @@ class MT5Block(nn.Module):
) )
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -636,11 +614,11 @@ class MT5Block(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
def load_tf_weights_in_mt5(model, config, tf_checkpoint_path): def load_tf_weights_in_mt5(model, config, tf_checkpoint_path):
@@ -780,6 +758,9 @@ class MT5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer" base_model_prefix = "transformer"
is_parallelizable = True is_parallelizable = True
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_quantized_cache = False # enc-dec models don't support yet
_supports_static_cache = True
_supports_cache_class = True
_no_split_modules = ["MT5Block"] _no_split_modules = ["MT5Block"]
_keep_in_fp32_modules = ["wo"] _keep_in_fp32_modules = ["wo"]
@@ -892,7 +873,7 @@ class MT5Stack(MT5PreTrainedModel):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.block = nn.ModuleList( self.block = nn.ModuleList(
[MT5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
) )
self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -968,6 +949,7 @@ class MT5Stack(MT5PreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
@@ -994,6 +976,13 @@ class MT5Stack(MT5PreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
if self.embed_tokens is None: if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings") raise ValueError("You have to initialize the model with valid token embeddings")
@@ -1001,23 +990,57 @@ class MT5Stack(MT5PreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
if not self.is_decoder: if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
# initialize past_key_values with `None` if past does not exist # initialize past_key_values
if past_key_values is None: return_legacy_cache = False
past_key_values = [None] * len(self.block) return_self_attention_cache = False
if self.is_decoder and (use_cache or past_key_values is not None):
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
if attention_mask is None: past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
else:
causal_mask = None
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1032,17 +1055,9 @@ class MT5Stack(MT5PreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
@@ -1051,15 +1066,15 @@ class MT5Stack(MT5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if causal_mask is not None:
attention_mask = attention_mask.to(hidden_states.device) causal_mask = causal_mask.to(hidden_states.device)
if position_bias is not None: if position_bias is not None:
position_bias = position_bias.to(hidden_states.device) position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
@@ -1079,7 +1094,7 @@ class MT5Stack(MT5PreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -1089,20 +1104,24 @@ class MT5Stack(MT5PreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
return_dict,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -1110,7 +1129,7 @@ class MT5Stack(MT5PreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -1118,9 +1137,6 @@ class MT5Stack(MT5PreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -1140,12 +1156,18 @@ class MT5Stack(MT5PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1154,12 +1176,135 @@ class MT5Stack(MT5PreTrainedModel):
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select or build the self-attention mask matching the configured attention implementation.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller, or `None`. Its last dimension is used as the key/value length when no
            static cache is in use.
        input_tensor (`torch.Tensor`):
            Tensor whose dtype and device are used for the mask, and whose dimensions 0 and 1 are read as the
            batch size and the current sequence length.
        cache_position (`torch.Tensor`):
            Forwarded unchanged to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values (`Cache`):
            Cache object, or `None`. Queried for the number of already seen tokens and, when it is a
            `StaticCache`, for its maximum shape.
        output_attentions (`bool`):
            When truthy, the SDPA-specific shortcuts below are skipped.

    Returns:
        `None` when no explicit mask is passed on, the incoming `attention_mask` itself on the
        flash-attention-2 path when it contains a `0.0`, otherwise the mask produced by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
    """
    attn_impl = self.config._attn_implementation

    if attn_impl == "flash_attention_2":
        # Only hand a mask over when at least one entry equals 0.0.
        has_zero_entry = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_zero_entry else None

    # For SDPA, when possible, rely on its `is_causal` argument instead of its `attn_mask` argument, in order to
    # dispatch on Flash Attention 2. This is not compatible with static cache, as SDPA will fail to infer the
    # attention mask.
    past_seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    using_static_cache = isinstance(past_key_values, StaticCache)

    # When output attentions is True, the sdpa implementation's forward method calls the eager implementation's
    # forward, so the shortcut is only attempted when attention weights are not requested.
    if (
        attn_impl == "sdpa"
        and not using_static_cache
        and not output_attentions
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        )
    ):
        return None

    dtype = input_tensor.dtype
    device = input_tensor.device
    sequence_length = input_tensor.shape[1]

    # Key/value length: the static cache capacity, else the width of the given mask, else derived from the
    # number of cached tokens and the current sequence length.
    if using_static_cache:
        target_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        target_length = attention_mask.shape[-1]
    else:
        target_length = past_seen_tokens + sequence_length + 1

    # In case the provided `attention_mask` is 2D, a causal 4D mask is generated here.
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmasking = (
        attn_impl == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    )
    if not needs_unmasking:
        return causal_mask

    # Attend to all tokens in fully masked rows of the causal mask, for example the relevant first rows when
    # using left padding. This is required by F.scaled_dot_product_attention's memory-efficient attention path.
    # Details: https://github.com/pytorch/pytorch/issues/110213
    return AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(dtype).min)
@staticmethod
# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        **kwargs:
            Accepted for signature compatibility and ignored.

    Returns:
        `torch.Tensor`: the incoming mask itself when it is already 4D; otherwise a tensor of shape
        `(batch_size, 1, sequence_length, target_length)` holding `0` at positions that may be attended to and
        `torch.finfo(dtype).min` at masked positions.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the fill value only at key positions strictly after each query's cache position. `cache_position` is
    # moved next to the mask first so tensors living on another device do not raise a device-mismatch error.
    query_positions = cache_position.reshape(-1, 1).to(causal_mask.device)
    causal_mask *= torch.arange(target_length, device=causal_mask.device) > query_positions
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # Same device alignment for the 2D mask, which may arrive on a different device than the one the
        # 4D mask was created on.
        key_mask = attention_mask[:, None, None, :].to(causal_mask.device)
        # A sum of exactly 0 means "causally visible (0) but padded (0)": those positions get masked out.
        padding_mask = (causal_mask[:, :, :, :mask_length] + key_mask) == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
MT5_START_DOCSTRING = r""" MT5_START_DOCSTRING = r"""
@@ -1454,6 +1599,7 @@ class MT5Model(MT5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
@@ -1533,6 +1679,7 @@ class MT5Model(MT5PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1685,6 +1832,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1779,6 +1927,7 @@ class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -22,7 +22,9 @@ import torch.utils.checkpoint
from torch import nn from torch import nn
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPooling, BaseModelOutputWithPooling,
@@ -38,6 +40,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -184,14 +187,17 @@ class Pix2StructVisionAttention(nn.Module):
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
if attention_mask is None:
attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype)
if attention_mask.dim() == 2: if attention_mask.dim() == 2:
position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device)
else: elif attention_mask is not None:
# (batch_size, n_heads, seq_length, key_length) # (batch_size, n_heads, seq_length, key_length)
position_bias = position_bias + attention_mask.to(position_bias.device) position_bias = position_bias + attention_mask.to(position_bias.device)
elif not is_torchdynamo_compiling():
attention_mask = torch.ones(
(batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype
)
position_bias = position_bias + attention_mask.to(position_bias.device)
position_bias = 1 - position_bias position_bias = 1 - position_bias
position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min)
@@ -355,6 +361,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel):
""" """
config_class = Pix2StructConfig config_class = Pix2StructConfig
_supports_cache_class = True
_supports_static_cache = False
@property @property
def dummy_inputs(self): def dummy_inputs(self):
@@ -673,7 +681,9 @@ class Pix2StructTextLayerFF(nn.Module):
class Pix2StructTextAttention(nn.Module): class Pix2StructTextAttention(nn.Module):
def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False): def __init__(
self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None
):
super().__init__() super().__init__()
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
self.relative_attention_num_buckets = config.relative_attention_num_buckets self.relative_attention_num_buckets = config.relative_attention_num_buckets
@@ -683,6 +693,13 @@ class Pix2StructTextAttention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
@@ -773,75 +790,56 @@ class Pix2StructTextAttention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.query(hidden_states).contiguous()
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def to_projection_shape(states):
"""projection"""
return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = to_projection_shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = to_projection_shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = to_projection_shape(proj_layer(key_value_states))
else: else:
# cross-attn past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states
# (batch_size, n_heads, seq_length, dim_per_head)
query_states = to_projection_shape(self.query(hidden_states))
# get key/value states # get key/value states
key_states = project( current_states = key_value_states if is_cross_attention else hidden_states
hidden_states, self.key, key_value_states, past_key_value[0] if past_key_value is not None else None if is_cross_attention and past_key_value and is_updated:
) # reuse k,v, cross_attentions
value_states = project( key_states = past_key_value.key_cache[self.layer_idx]
hidden_states, self.value, key_value_states, past_key_value[1] if past_key_value is not None else None value_states = past_key_value.value_cache[self.layer_idx]
else:
key_states = self.key(current_states).contiguous()
value_states = self.value(current_states).contiguous()
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
cache_position = cache_position if not is_cross_attention else None
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, {"cache_position": cache_position}
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
real_seq_length = cache_position[-1] + 1 if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype
@@ -851,11 +849,6 @@ class Pix2StructTextAttention(nn.Module):
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device)
# if key and values are already calculated
# we want only the last query position bias
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length)
@@ -883,19 +876,20 @@ class Pix2StructTextAttention(nn.Module):
attn_output = self.output(attn_output) attn_output = self.output(attn_output)
present_key_value_state = (key_states, value_states) if use_cache else None outputs = (attn_output,) + (past_key_value,) + (position_bias,)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
return outputs return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerSelfAttention(nn.Module): class Pix2StructTextLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.attention = Pix2StructTextAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -908,6 +902,7 @@ class Pix2StructTextLayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention( attention_output = self.attention(
@@ -918,17 +913,18 @@ class Pix2StructTextLayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
return outputs return outputs
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size
class Pix2StructTextLayerCrossAttention(nn.Module): class Pix2StructTextLayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False) self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -943,6 +939,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention( attention_output = self.attention(
@@ -955,6 +952,7 @@ class Pix2StructTextLayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -962,11 +960,13 @@ class Pix2StructTextLayerCrossAttention(nn.Module):
class Pix2StructTextBlock(nn.Module): class Pix2StructTextBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
    """
    Build the self-attention and cross-attention sub-layers of one text decoder block.

    Args:
        config: Model configuration forwarded to both attention sub-layers.
        has_relative_attention_bias (`bool`, *optional*, defaults to `False`):
            Forwarded to the self-attention sub-layer only.
        layer_idx (`int`, *optional*):
            Index of this block in the decoder stack. Both attention sub-layers use it as the key under which
            they read and write their entries in the key/value cache.
    """
    super().__init__()

    self.self_attention = Pix2StructTextLayerSelfAttention(
        config,
        has_relative_attention_bias=has_relative_attention_bias,
        layer_idx=layer_idx,
    )

    # The cross-attention layer must receive the same `layer_idx`: its attention module calls
    # `past_key_value.update(..., self.layer_idx, ...)` and sets `past_key_value.is_updated[self.layer_idx]`.
    # Leaving it at the default `None` would make every block share a single cache slot / "updated" flag.
    self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(
        config,
        layer_idx=layer_idx,
    )
@@ -987,32 +987,19 @@ class Pix2StructTextBlock(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.self_attention( self_attention_outputs = self.self_attention(
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -1022,35 +1009,25 @@ class Pix2StructTextBlock(nn.Module):
do_cross_attention = encoder_hidden_states is not None do_cross_attention = encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.encoder_decoder_attention( cross_attention_outputs = self.encoder_decoder_attention(
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -1065,7 +1042,7 @@ class Pix2StructTextBlock(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
@@ -1187,6 +1164,9 @@ PIX2STRUCT_TEXT_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
PIX2STRUCT_INPUTS_DOCSTRING = r""" PIX2STRUCT_INPUTS_DOCSTRING = r"""
@@ -1293,7 +1273,10 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.layer = nn.ModuleList( self.layer = nn.ModuleList(
[Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] [
Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
for i in range(config.num_layers)
]
) )
self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -1364,6 +1347,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None, labels: Optional[torch.LongTensor] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
**kwargs, **kwargs,
) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]:
r""" r"""
@@ -1405,24 +1389,54 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past # initialize past_key_values
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length return_legacy_cache = False
return_self_attention_cache = False
if use_cache or past_key_values is not None:
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
if attention_mask is None: past_key_values_length = 0
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) if cache_position is not None:
if encoder_attention_mask is None and encoder_hidden_states is not None: past_key_values_length = cache_position[0]
encoder_seq_length = encoder_hidden_states.shape[1] elif past_key_values is not None:
encoder_attention_mask = torch.ones( past_key_values_length = past_key_values.get_seq_length()
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
) )
# initialize past_key_values with `None` if past does not exist if attention_mask is None:
if past_key_values is None: # required mask seq length can be calculated via length of past
past_key_values = [None] * len(self.layer) mask_seq_length = (
past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length
)
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
else:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1438,7 +1452,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions) else None all_cross_attentions = () if (output_attentions) else None
@@ -1447,7 +1460,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.layer, past_key_values)): for i, layer_module in enumerate(self.layer):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
if output_hidden_states: if output_hidden_states:
@@ -1462,7 +1475,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -1472,20 +1485,22 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -1493,7 +1508,7 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -1501,9 +1516,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -1527,13 +1539,19 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1))
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
loss, loss,
logits, logits,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1543,12 +1561,135 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
return CausalLMOutputWithCrossAttentions( return CausalLMOutputWithCrossAttentions(
loss=loss, loss=loss,
logits=logits, logits=logits,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select or build the self-attention mask for the configured attention implementation.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller; may be `None`. It is forwarded to
            `_prepare_4d_causal_attention_mask_with_cache_position` when a 4D mask has to be built.
        input_tensor (`torch.Tensor`):
            Its dtype, device, `shape[0]` (batch size) and `shape[1]` (sequence length) size the mask.
        cache_position (`torch.Tensor`):
            Forwarded unchanged to the 4D mask builder.
        past_key_values (`Cache`):
            Cache queried for the number of already seen tokens; may be `None`. A `StaticCache` also fixes the
            mask width to `get_max_cache_shape()`.
        output_attentions (`bool`):
            When `True`, the SDPA shortcuts below are not taken.

    Returns:
        `None`, the unchanged `attention_mask` (flash-attention path only), or the mask returned by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
    """
    # Flash Attention 2: hand the caller's mask back only if it contains a 0.0 entry, otherwise use no mask.
    if self.config._attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    using_static_cache = isinstance(past_key_values, StaticCache)

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        ):
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    sequence_length = input_tensor.shape[1]
    # Mask width: the static cache's maximum shape, else the last dimension of a tensor mask, else
    # past tokens + current tokens + 1.
    if using_static_cache:
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        min_dtype = torch.finfo(dtype).min
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask
@staticmethod
# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
# (differs from the original by moving `attention_mask` onto the mask's device before combining them)
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the input `attention_mask` itself when it is already 4D, otherwise an additive mask of shape
        `(batch_size, 1, sequence_length, target_length)` holding `0` where attention is allowed and
        `torch.finfo(dtype).min` where it is blocked.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the blocking value only for key positions strictly after each query's cache position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # The 2D mask may live on a different device than the freshly built causal mask; align it first so the
        # addition below cannot fail with a device mismatch.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
        # A sum of exactly 0 means "causally visible" (0) and "padding" (0) at the same time: block it.
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
@add_start_docstrings( @add_start_docstrings(
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
@@ -1615,6 +1756,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMi
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
@@ -1723,6 +1865,7 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel, GenerationMi
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
labels=labels, labels=labels,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:

View File

@@ -25,7 +25,9 @@ from torch.nn import CrossEntropyLoss
from transformers.generation import GenerationConfig from transformers.generation import GenerationConfig
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
@@ -37,6 +39,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -136,6 +139,9 @@ POP2PIANO_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
@@ -245,7 +251,12 @@ class Pop2PianoLayerFF(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoAttention(nn.Module): class Pop2PianoAttention(nn.Module):
def __init__(self, config: Pop2PianoConfig, has_relative_attention_bias=False): def __init__(
self,
config: Pop2PianoConfig,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -256,6 +267,13 @@ class Pop2PianoAttention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -332,11 +350,14 @@ class Pop2PianoAttention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -360,94 +381,72 @@ class Pop2PianoAttention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -457,22 +456,22 @@ class Pop2PianoAttention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -481,9 +480,11 @@ class Pop2PianoAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerSelfAttention(nn.Module): class Pop2PianoLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = Pop2PianoAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.SelfAttention = Pop2PianoAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -496,6 +497,7 @@ class Pop2PianoLayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -506,6 +508,7 @@ class Pop2PianoLayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -514,9 +517,9 @@ class Pop2PianoLayerSelfAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Pop2Piano,t5->pop2piano
class Pop2PianoLayerCrossAttention(nn.Module): class Pop2PianoLayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False) self.EncDecAttention = Pop2PianoAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -531,6 +534,7 @@ class Pop2PianoLayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -543,6 +547,7 @@ class Pop2PianoLayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -551,13 +556,17 @@ class Pop2PianoLayerCrossAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Pop2Piano,t5->pop2piano
class Pop2PianoBlock(nn.Module): class Pop2PianoBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(Pop2PianoLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) self.layer.append(
Pop2PianoLayerSelfAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
)
if self.is_decoder: if self.is_decoder:
self.layer.append(Pop2PianoLayerCrossAttention(config)) self.layer.append(Pop2PianoLayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(Pop2PianoLayerFF(config)) self.layer.append(Pop2PianoLayerFF(config))
@@ -575,34 +584,19 @@ class Pop2PianoBlock(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -616,25 +610,18 @@ class Pop2PianoBlock(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
@@ -645,10 +632,6 @@ class Pop2PianoBlock(nn.Module):
) )
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -667,11 +650,11 @@ class Pop2PianoBlock(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class Pop2PianoPreTrainedModel(PreTrainedModel): class Pop2PianoPreTrainedModel(PreTrainedModel):
@@ -684,6 +667,8 @@ class Pop2PianoPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer" base_model_prefix = "transformer"
is_parallelizable = False is_parallelizable = False
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = False
_no_split_modules = ["Pop2PianoBlock"] _no_split_modules = ["Pop2PianoBlock"]
_keep_in_fp32_modules = ["wo"] _keep_in_fp32_modules = ["wo"]
@@ -769,7 +754,10 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.block = nn.ModuleList( self.block = nn.ModuleList(
[Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] [
Pop2PianoBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
for i in range(config.num_layers)
]
) )
self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = Pop2PianoLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -803,6 +791,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -825,6 +814,13 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
if self.embed_tokens is None: if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings") raise ValueError("You have to initialize the model with valid token embeddings")
@@ -832,28 +828,55 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
if not self.is_decoder: if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if attention_mask is None: # initialize past_key_values
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) return_legacy_cache = False
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: return_self_attention_cache = False
encoder_seq_length = encoder_hidden_states.shape[1] if self.is_decoder and (use_cache or past_key_values is not None):
encoder_attention_mask = torch.ones( if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
) )
# initialize past_key_values with `None` if past does not exist if attention_mask is None and not is_torchdynamo_compiling():
if past_key_values is None: # required mask seq length can be calculated via length of past cache
past_key_values = [None] * len(self.block) mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
else:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -866,17 +889,9 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
@@ -885,7 +900,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
if output_hidden_states: if output_hidden_states:
@@ -895,7 +910,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -905,20 +920,22 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -926,7 +943,7 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -934,9 +951,6 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -950,12 +964,18 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -964,12 +984,135 @@ class Pop2PianoStack(Pop2PianoPreTrainedModel):
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Build the attention mask handed to the decoder layers, or decide that no explicit mask is needed.

    Args:
        attention_mask (`torch.Tensor`):
            Attention mask supplied by the caller. The body also handles `None` (see the `is not None` /
            `isinstance` checks below).
        input_tensor (`torch.Tensor`):
            Tensor whose dtype, device, `shape[0]` (batch size) and `shape[1]` (sequence length) are used to
            build the mask.
        cache_position (`torch.Tensor`):
            Forwarded unchanged to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values (`Cache`):
            Cache queried for the number of already seen tokens; `None` is treated as zero seen tokens.
        output_attentions (`bool`):
            When `True`, the SDPA-specific shortcuts below are skipped.

    Returns:
        One of:
        - for `flash_attention_2`: `attention_mask` itself if it contains a `0.0` entry, otherwise `None`;
        - `None` when, under SDPA, `AttentionMaskConverter._ignore_causal_mask_sdpa` returns a truthy value;
        - otherwise the mask produced by `_prepare_4d_causal_attention_mask_with_cache_position`, optionally
          post-processed by `AttentionMaskConverter._unmask_unattended`.
    """
    if self.config._attn_implementation == "flash_attention_2":
        # Only hand a mask over when it actually masks something (contains a 0.0 entry).
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    using_static_cache = isinstance(past_key_values, StaticCache)

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
        # NOTE(review): judging by its name, the helper reports whether the explicit mask can be dropped;
        # confirm against `AttentionMaskConverter`.
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        ):
            return None

    dtype, device = input_tensor.dtype, input_tensor.device
    sequence_length = input_tensor.shape[1]
    if using_static_cache:
        # With a static cache the key/value axis length comes from the cache itself.
        target_length = past_key_values.get_max_cache_shape()
    else:
        # Otherwise use the last dimension of the provided mask, or derive a length from the number of
        # seen tokens when no tensor mask was given.
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        min_dtype = torch.finfo(dtype).min
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

    return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the input `attention_mask` itself when it is already 4D; otherwise a tensor of shape
        `(batch_size, 1, sequence_length, target_length)` holding `0` at positions that may be attended to and
        `torch.finfo(dtype).min` at masked positions.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Keep the fill value only on key positions that lie strictly after each query's cache position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # The 2D mask may live on a different device than `device`; move it before combining, otherwise the
        # addition raises a device-mismatch error.
        # NOTE(review): this line diverges from the Llama original named in the "Copied from" marker above;
        # mirror the change there (or drop the marker) to keep the copy-consistency check passing.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
            causal_mask.device
        )
        # A sum of 0 means "causally visible (0) and padded (0)": those positions must be masked as well.
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
class Pop2PianoConcatEmbeddingToMel(nn.Module): class Pop2PianoConcatEmbeddingToMel(nn.Module):
"""Embedding Matrix for `composer` tokens.""" """Embedding Matrix for `composer` tokens."""
@@ -1122,6 +1265,7 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixi
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1177,6 +1321,7 @@ class Pop2PianoForConditionalGeneration(Pop2PianoPreTrainedModel, GenerationMixi
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -24,7 +24,9 @@ import torch.nn as nn
from torch.nn import CrossEntropyLoss from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
MoEModelOutput, MoEModelOutput,
MoEModelOutputWithPastAndCrossAttentions, MoEModelOutputWithPastAndCrossAttentions,
@@ -39,6 +41,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -355,7 +358,12 @@ class SwitchTransformersLayerFF(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->SwitchTransformers
class SwitchTransformersAttention(nn.Module): class SwitchTransformersAttention(nn.Module):
def __init__(self, config: SwitchTransformersConfig, has_relative_attention_bias=False): def __init__(
self,
config: SwitchTransformersConfig,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -366,6 +374,13 @@ class SwitchTransformersAttention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -442,11 +457,14 @@ class SwitchTransformersAttention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -470,94 +488,72 @@ class SwitchTransformersAttention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -567,22 +563,22 @@ class SwitchTransformersAttention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -591,10 +587,10 @@ class SwitchTransformersAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->SwitchTransformers
class SwitchTransformersLayerSelfAttention(nn.Module): class SwitchTransformersLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = SwitchTransformersAttention( self.SelfAttention = SwitchTransformersAttention(
config, has_relative_attention_bias=has_relative_attention_bias config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
) )
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -608,6 +604,7 @@ class SwitchTransformersLayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -618,6 +615,7 @@ class SwitchTransformersLayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -626,9 +624,11 @@ class SwitchTransformersLayerSelfAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->SwitchTransformers
class SwitchTransformersLayerCrossAttention(nn.Module): class SwitchTransformersLayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = SwitchTransformersAttention(config, has_relative_attention_bias=False) self.EncDecAttention = SwitchTransformersAttention(
config, has_relative_attention_bias=False, layer_idx=layer_idx
)
self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -643,6 +643,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -655,6 +656,7 @@ class SwitchTransformersLayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -662,16 +664,18 @@ class SwitchTransformersLayerCrossAttention(nn.Module):
class SwitchTransformersBlock(nn.Module): class SwitchTransformersBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False, is_sparse=False): def __init__(self, config, has_relative_attention_bias=False, is_sparse=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.is_sparse = is_sparse self.is_sparse = is_sparse
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append( self.layer.append(
SwitchTransformersLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias) SwitchTransformersLayerSelfAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
) )
if self.is_decoder: if self.is_decoder:
self.layer.append(SwitchTransformersLayerCrossAttention(config)) self.layer.append(SwitchTransformersLayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse)) self.layer.append(SwitchTransformersLayerFF(config, is_sparse=self.is_sparse))
@@ -690,34 +694,19 @@ class SwitchTransformersBlock(nn.Module):
output_attentions=False, output_attentions=False,
output_router_logits=True, output_router_logits=True,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -727,35 +716,25 @@ class SwitchTransformersBlock(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
clamp_value = torch.finfo(hidden_states.dtype).max - 1000 clamp_value = torch.finfo(hidden_states.dtype).max - 1000
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -775,11 +754,11 @@ class SwitchTransformersBlock(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs + (router_tuple,) outputs = outputs + (past_key_value,) + attention_outputs + (router_tuple,)
else: else:
outputs = outputs + attention_outputs + (router_tuple,) outputs = outputs + attention_outputs + (router_tuple,)
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple) return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights), (router_tuple)
class SwitchTransformersPreTrainedModel(PreTrainedModel): class SwitchTransformersPreTrainedModel(PreTrainedModel):
@@ -791,6 +770,8 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
config_class = SwitchTransformersConfig config_class = SwitchTransformersConfig
base_model_prefix = "switch_transformers" base_model_prefix = "switch_transformers"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = False
_no_split_modules = ["SwitchTransformersBlock"] _no_split_modules = ["SwitchTransformersBlock"]
@property @property
@@ -897,7 +878,9 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False is_sparse = (i % sparse_step == 1 or sparse_step == 1) if sparse_step > 0 else False
self.block.append( self.block.append(
SwitchTransformersBlock(config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse) SwitchTransformersBlock(
config, has_relative_attention_bias=bool(i == 0), is_sparse=is_sparse, layer_idx=i
)
) )
self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = SwitchTransformersLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -930,6 +913,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
output_hidden_states=None, output_hidden_states=None,
output_router_logits=True, output_router_logits=True,
return_dict=None, return_dict=None,
cache_position=None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -952,6 +936,13 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
if self.embed_tokens is None: if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings") raise ValueError("You have to initialize the model with valid token embeddings")
@@ -959,28 +950,55 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
if not self.is_decoder: if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if attention_mask is None: # initialize past_key_values
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) return_legacy_cache = False
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: return_self_attention_cache = False
encoder_seq_length = encoder_hidden_states.shape[1] if self.is_decoder and (use_cache or past_key_values is not None):
encoder_attention_mask = torch.ones( if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
) )
# initialize past_key_values with `None` if past does not exist if attention_mask is None and not is_torchdynamo_compiling():
if past_key_values is None: # required mask seq length can be calculated via length of past cache
past_key_values = [None] * len(self.block) mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
else:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -993,17 +1011,9 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_router_probs = () if output_router_logits else None all_router_probs = () if output_router_logits else None
@@ -1013,7 +1023,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -1024,7 +1034,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -1034,21 +1044,26 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
output_router_logits,
return_dict,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict,
cache_position=cache_position,
) )
router_probs = layer_outputs[-1] router_probs = layer_outputs[-1]
@@ -1059,7 +1074,7 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -1067,9 +1082,6 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -1086,12 +1098,18 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1101,13 +1119,136 @@ class SwitchTransformersStack(SwitchTransformersPreTrainedModel):
) )
return MoEModelOutputWithPastAndCrossAttentions( return MoEModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
router_probs=all_router_probs, router_probs=all_router_probs,
) )
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select or build the self-attention mask for the current forward pass.

    The outcome depends on `self.config._attn_implementation`:

    - `"flash_attention_2"`: the 2D `attention_mask` is returned untouched when it contains a `0.0` entry,
      otherwise `None` is returned.
    - `"sdpa"` without a `StaticCache` and without `output_attentions`: `None` is returned when
      `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped.
    - every other case: the mask produced by `_prepare_4d_causal_attention_mask_with_cache_position` is
      returned, post-processed with `AttentionMaskConverter._unmask_unattended` for SDPA on CUDA.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            Mask supplied by the caller. `None` is accepted.
        input_tensor (`torch.Tensor`):
            Tensor providing the dtype and device of the mask, the batch size (`shape[0]`) and the query
            length (`shape[1]`).
        cache_position (`torch.Tensor`):
            Forwarded unchanged to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values (`Cache`, *optional*):
            Cache queried for the number of already processed tokens. `None` means no past.
        output_attentions (`bool`):
            When truthy, the SDPA-only shortcuts are skipped.

    Returns:
        The `attention_mask` itself, `None`, or the prepared mask, as described above.
    """
    attn_implementation = self.config._attn_implementation

    if attn_implementation == "flash_attention_2":
        # The mask is only forwarded when at least one position is actually masked out.
        contains_padding = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if contains_padding else None

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache_in_use = isinstance(past_key_values, StaticCache)

    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    sdpa_without_weights = attn_implementation == "sdpa" and not output_attentions
    if sdpa_without_weights and not static_cache_in_use:
        mask_is_redundant = AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=self.training,
        )
        if mask_is_redundant:
            return None

    mask_dtype = input_tensor.dtype
    mask_device = input_tensor.device
    query_length = input_tensor.shape[1]

    # Width of the key/value axis of the mask.
    if static_cache_in_use:
        kv_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        kv_length = attention_mask.shape[-1]
    else:
        kv_length = seen_tokens + query_length + 1

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    prepared_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_length,
        target_length=kv_length,
        dtype=mask_dtype,
        device=mask_device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmasking = (
        sdpa_without_weights and attention_mask is not None and attention_mask.device.type == "cuda"
    )
    if needs_unmasking:
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        prepared_mask = AttentionMaskConverter._unmask_unattended(prepared_mask, torch.finfo(mask_dtype).min)

    return prepared_mask
@staticmethod
# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the 4D `attention_mask` unchanged when one was given, otherwise a mask of shape
        `(batch_size, 1, sequence_length, target_length)` holding `0` at positions that may be attended to and
        `torch.finfo(dtype).min` everywhere else.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # `cache_position` is moved to `device` so the comparison never mixes devices (no-op when it already
        # lives there). Key positions after the query position keep `min_dtype`, all others become 0.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.to(device).reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # Same device alignment for the padding mask, which may arrive on another device than the one
            # the 4D mask is built on.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(device)
            # A sum of 0 means "causally visible" (0) and "padded" (0) at the same time: mask those out.
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask
SWITCH_TRANSFORMERS_START_DOCSTRING = r""" SWITCH_TRANSFORMERS_START_DOCSTRING = r"""
@@ -1228,6 +1369,9 @@ SWITCH_TRANSFORMERS_INPUTS_DOCSTRING = r"""
should not be returned during inference. should not be returned during inference.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r""" SWITCH_TRANSFORMERS_ENCODER_INPUTS_DOCSTRING = r"""
@@ -1355,6 +1499,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = None, output_router_logits: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEModelOutput]:
r""" r"""
Returns: Returns:
@@ -1435,6 +1580,7 @@ class SwitchTransformersModel(SwitchTransformersPreTrainedModel):
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1535,6 +1681,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
output_router_logits: Optional[bool] = True, output_router_logits: Optional[bool] = True,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqMoEOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1618,6 +1765,7 @@ class SwitchTransformersForConditionalGeneration(SwitchTransformersPreTrainedMod
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
output_router_logits=output_router_logits, output_router_logits=output_router_logits,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -73,7 +73,12 @@ class T5Config(PretrainedConfig):
model_type = "t5" model_type = "t5"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
}
def __init__( def __init__(
self, self,

View File

@@ -25,7 +25,9 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
@@ -43,6 +45,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -339,7 +342,12 @@ class T5LayerFF(nn.Module):
class T5Attention(nn.Module): class T5Attention(nn.Module):
def __init__(self, config: T5Config, has_relative_attention_bias=False): def __init__(
self,
config: T5Config,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -350,6 +358,13 @@ class T5Attention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -426,11 +441,14 @@ class T5Attention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -454,94 +472,72 @@ class T5Attention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -551,22 +547,22 @@ class T5Attention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -574,9 +570,11 @@ class T5Attention(nn.Module):
class T5LayerSelfAttention(nn.Module): class T5LayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias) self.SelfAttention = T5Attention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -589,6 +587,7 @@ class T5LayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -599,6 +598,7 @@ class T5LayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -606,9 +606,9 @@ class T5LayerSelfAttention(nn.Module):
class T5LayerCrossAttention(nn.Module): class T5LayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False) self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -623,6 +623,7 @@ class T5LayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -635,6 +636,7 @@ class T5LayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -642,13 +644,15 @@ class T5LayerCrossAttention(nn.Module):
class T5Block(nn.Module): class T5Block(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) self.layer.append(
T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
)
if self.is_decoder: if self.is_decoder:
self.layer.append(T5LayerCrossAttention(config)) self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(T5LayerFF(config)) self.layer.append(T5LayerFF(config))
@@ -666,34 +670,19 @@ class T5Block(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -707,25 +696,18 @@ class T5Block(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
@@ -736,10 +718,6 @@ class T5Block(nn.Module):
) )
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -758,11 +736,11 @@ class T5Block(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class T5ClassificationHead(nn.Module): class T5ClassificationHead(nn.Module):
@@ -794,6 +772,9 @@ class T5PreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer" base_model_prefix = "transformer"
is_parallelizable = True is_parallelizable = True
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_quantized_cache = False # enc-dec models don't support yet
_supports_static_cache = True
_supports_cache_class = True
_no_split_modules = ["T5Block"] _no_split_modules = ["T5Block"]
_keep_in_fp32_modules = ["wo"] _keep_in_fp32_modules = ["wo"]
@@ -905,7 +886,7 @@ class T5Stack(T5PreTrainedModel):
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.block = nn.ModuleList( self.block = nn.ModuleList(
[T5Block(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
) )
self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -981,6 +962,7 @@ class T5Stack(T5PreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
@@ -1007,6 +989,13 @@ class T5Stack(T5PreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
if self.embed_tokens is None: if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings") raise ValueError("You have to initialize the model with valid token embeddings")
@@ -1014,23 +1003,57 @@ class T5Stack(T5PreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
if not self.is_decoder: if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
# initialize past_key_values with `None` if past does not exist # initialize past_key_values
if past_key_values is None: return_legacy_cache = False
past_key_values = [None] * len(self.block) return_self_attention_cache = False
if self.is_decoder and (use_cache or past_key_values is not None):
if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
if attention_mask is None: past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
if attention_mask is None and not is_torchdynamo_compiling():
# required mask seq length can be calculated via length of past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.config.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
else:
causal_mask = None
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -1045,17 +1068,9 @@ class T5Stack(T5PreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
@@ -1064,15 +1079,15 @@ class T5Stack(T5PreTrainedModel):
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
# Model parallel # Model parallel
if self.model_parallel: if self.model_parallel:
torch.cuda.set_device(hidden_states.device) torch.cuda.set_device(hidden_states.device)
# Ensure that attention_mask is always on the same device as hidden_states # Ensure that attention_mask is always on the same device as hidden_states
if attention_mask is not None: if causal_mask is not None:
attention_mask = attention_mask.to(hidden_states.device) causal_mask = causal_mask.to(hidden_states.device)
if position_bias is not None: if position_bias is not None:
position_bias = position_bias.to(hidden_states.device) position_bias = position_bias.to(hidden_states.device)
if encoder_hidden_states is not None: if encoder_hidden_states is not None:
@@ -1092,7 +1107,7 @@ class T5Stack(T5PreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
position_bias, position_bias,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
@@ -1102,20 +1117,24 @@ class T5Stack(T5PreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
return_dict,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
return_dict=return_dict,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
@@ -1123,7 +1142,7 @@ class T5Stack(T5PreTrainedModel):
if use_cache is False: if use_cache is False:
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
@@ -1131,9 +1150,6 @@ class T5Stack(T5PreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[3],) all_attentions = all_attentions + (layer_outputs[3],)
@@ -1153,12 +1169,18 @@ class T5Stack(T5PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1167,12 +1189,135 @@ class T5Stack(T5PreTrainedModel):
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Adapted from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select or build the self-attention mask for the current forward pass.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller, or `None`. It is forwarded unchanged to
            `_prepare_4d_causal_attention_mask_with_cache_position` when a 4D mask has to be built.
        input_tensor (`torch.Tensor`):
            Tensor whose `shape[0]` / `shape[1]` are used as batch size / query length and whose dtype and device
            are used for the generated mask.
        cache_position (`torch.Tensor`):
            Positions of the current tokens, forwarded to the 4D mask builder.
        past_key_values (`Cache`):
            Cache object (or `None`) used to read the number of already cached tokens and, for a `StaticCache`,
            the maximum cache length.
        output_attentions (`bool`):
            Whether attention weights are requested; this disables the SDPA shortcuts below.

    Returns:
        `None`, the caller's `attention_mask` (flash-attention path only), or the 4D mask produced by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
    """
    attn_implementation = self.config._attn_implementation

    # With flash_attention_2 no 4D mask is built: the given mask is handed back only if it contains a 0.0 entry.
    if attn_implementation == "flash_attention_2":
        has_masked_entry = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_masked_entry else None

    cached_length = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache_in_use = isinstance(past_key_values, StaticCache)
    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    sdpa_without_attn_outputs = attn_implementation == "sdpa" and not output_attentions

    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    if (
        sdpa_without_attn_outputs
        and not static_cache_in_use
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=cached_length,
            is_training=self.training,
        )
    ):
        return None

    query_length = input_tensor.shape[1]
    if static_cache_in_use:
        key_value_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        key_value_length = attention_mask.shape[-1]
    else:
        key_value_length = cached_length + query_length + 1

    mask_dtype = input_tensor.dtype

    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_length,
        target_length=key_value_length,
        dtype=mask_dtype,
        device=input_tensor.device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    needs_unmasking = (
        sdpa_without_attn_outputs and attention_mask is not None and attention_mask.device.type == "cuda"
    )
    if needs_unmasking:
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, torch.finfo(mask_dtype).min)

    return causal_mask
@staticmethod
# Adapted from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            structure is encoded.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the input mask itself when it is already 4D, otherwise a
        `(batch_size, 1, sequence_length, target_length)` tensor holding `0` at attendable positions and
        `torch.finfo(dtype).min` at masked positions.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        return attention_mask

    min_dtype = torch.finfo(dtype).min
    causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
    if sequence_length != 1:
        causal_mask = torch.triu(causal_mask, diagonal=1)
    # Zero out every key position that is not strictly after the query's cache position.
    causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
    causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is not None:
        causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
        mask_length = attention_mask.shape[-1]
        # The 2D mask may live on a different device than `device`; move it before combining so the addition
        # does not fail with a device-mismatch error. This is a no-op when both are already co-located.
        padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(causal_mask.device)
        # A position is padding where the causal mask allows it (0) but the 2D mask forbids it (0).
        padding_mask = padding_mask == 0
        causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
            padding_mask, min_dtype
        )

    return causal_mask
T5_START_DOCSTRING = r""" T5_START_DOCSTRING = r"""
@@ -1286,6 +1431,9 @@ T5_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
T5_ENCODER_INPUTS_DOCSTRING = r""" T5_ENCODER_INPUTS_DOCSTRING = r"""
@@ -1446,6 +1594,7 @@ class T5Model(T5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
@@ -1525,6 +1674,7 @@ class T5Model(T5PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1656,6 +1806,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1750,6 +1901,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -34,13 +34,16 @@ from transformers.modeling_outputs import (
) )
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_utils import PreTrainedModel from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
from ...utils import ( from ...utils import (
ModelOutput, ModelOutput,
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torchdynamo_compiling,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -154,6 +157,9 @@ UDOP_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
@@ -411,6 +417,8 @@ class UdopPreTrainedModel(PreTrainedModel):
config_class = UdopConfig config_class = UdopConfig
base_model_prefix = "transformer" base_model_prefix = "transformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = False
_keep_in_fp32_modules = ["wo"] _keep_in_fp32_modules = ["wo"]
def _init_weights(self, module): def _init_weights(self, module):
@@ -598,7 +606,12 @@ class UdopLayerFF(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->Udop
class UdopAttention(nn.Module): class UdopAttention(nn.Module):
def __init__(self, config: UdopConfig, has_relative_attention_bias=False): def __init__(
self,
config: UdopConfig,
has_relative_attention_bias=False,
layer_idx: Optional[int] = None,
):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -609,6 +622,13 @@ class UdopAttention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -685,11 +705,14 @@ class UdopAttention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None].to(device)
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket( relative_position_bucket = self._relative_position_bucket(
@@ -713,94 +736,72 @@ class UdopAttention(nn.Module):
query_length=None, query_length=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
""" """
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
""" """
# Input is (batch_size, seq_length, dim) # Input is (batch_size, seq_length, dim)
# Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
# past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
real_seq_length = seq_length # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
is_cross_attention = key_value_states is not None
query_states = self.q(hidden_states)
query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None: if past_key_value is not None:
if len(past_key_value) != 2: is_updated = past_key_value.is_updated.get(self.layer_idx)
raise ValueError( if is_cross_attention:
f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" # after the first generated id, we can subsequently re-use all key/value_states from cache
) curr_past_key_value = past_key_value.cross_attention_cache
real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length
key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]
def shape(states):
"""projection"""
return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
def unshape(states):
"""reshape"""
return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)
def project(hidden_states, proj_layer, key_value_states, past_key_value):
"""projects hidden states correctly to key/query states"""
if key_value_states is None:
# self-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(hidden_states))
elif past_key_value is None:
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
if past_key_value is not None:
if key_value_states is None:
# self-attn
# (batch_size, n_heads, key_length, dim_per_head)
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
elif past_key_value.shape[2] != key_value_states.shape[1]:
# checking that the `sequence_length` of the `past_key_value` is the same as
# the provided `key_value_states` to support prefix tuning
# cross-attn
# (batch_size, n_heads, seq_length, dim_per_head)
hidden_states = shape(proj_layer(key_value_states))
else: else:
# cross-attn curr_past_key_value = past_key_value.self_attention_cache
hidden_states = past_key_value
return hidden_states
# get query states current_states = key_value_states if is_cross_attention else hidden_states
query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
key_states = self.k(current_states)
value_states = self.v(current_states)
key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# get key/value states if past_key_value is not None:
key_states = project( # save all key/value_states to cache to be re-used for fast auto-regressive generation
hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None cache_position = cache_position if not is_cross_attention else None
) key_states, value_states = curr_past_key_value.update(
value_states = project( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
) )
# set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul( scores = torch.matmul(query_states, key_states.transpose(3, 2))
query_states, key_states.transpose(3, 2)
) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
if position_bias is None: if position_bias is None:
key_length = key_states.shape[-2]
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
if not self.has_relative_attention_bias: if not self.has_relative_attention_bias:
position_bias = torch.zeros( position_bias = torch.zeros(
(1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
) )
if self.gradient_checkpointing and self.training: if self.gradient_checkpointing and self.training:
position_bias.requires_grad = True position_bias.requires_grad = True
else: else:
position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
# if key and values are already calculated )
# we want only the last query position bias position_bias = position_bias[:, :, -seq_length:, :]
if past_key_value is not None:
position_bias = position_bias[:, :, -hidden_states.size(1) :, :]
if mask is not None: if mask is not None:
position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) causal_mask = mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.pruned_heads: if self.pruned_heads:
mask = torch.ones(position_bias.shape[1]) mask = torch.ones(position_bias.shape[1])
@@ -810,22 +811,22 @@ class UdopAttention(nn.Module):
position_bias_masked = position_bias position_bias_masked = position_bias
scores += position_bias_masked scores += position_bias_masked
attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
scores # (batch_size, n_heads, seq_length, key_length)
) # (batch_size, n_heads, seq_length, key_length) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout( attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
attn_weights, p=self.dropout, training=self.training
) # (batch_size, n_heads, seq_length, key_length)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) attn_output = torch.matmul(attn_weights, value_states)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, -1, self.inner_dim)
attn_output = self.o(attn_output) attn_output = self.o(attn_output)
present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None outputs = (attn_output, past_key_value, position_bias)
outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)
if output_attentions: if output_attentions:
outputs = outputs + (attn_weights,) outputs = outputs + (attn_weights,)
@@ -834,9 +835,11 @@ class UdopAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->Udop
class UdopLayerSelfAttention(nn.Module): class UdopLayerSelfAttention(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = UdopAttention(config, has_relative_attention_bias=has_relative_attention_bias) self.SelfAttention = UdopAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -849,6 +852,7 @@ class UdopLayerSelfAttention(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -859,6 +863,7 @@ class UdopLayerSelfAttention(nn.Module):
past_key_value=past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -867,9 +872,9 @@ class UdopLayerSelfAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->Udop
class UdopLayerCrossAttention(nn.Module): class UdopLayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False) self.EncDecAttention = UdopAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -884,6 +889,7 @@ class UdopLayerCrossAttention(nn.Module):
use_cache=False, use_cache=False,
query_length=None, query_length=None,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -896,6 +902,7 @@ class UdopLayerCrossAttention(nn.Module):
use_cache=use_cache, use_cache=use_cache,
query_length=query_length, query_length=query_length,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -904,13 +911,17 @@ class UdopLayerCrossAttention(nn.Module):
# Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop # Copied from transformers.models.t5.modeling_t5.T5Block with T5->Udop
class UdopBlock(nn.Module): class UdopBlock(nn.Module):
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(UdopLayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)) self.layer.append(
UdopLayerSelfAttention(
config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
)
)
if self.is_decoder: if self.is_decoder:
self.layer.append(UdopLayerCrossAttention(config)) self.layer.append(UdopLayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(UdopLayerFF(config)) self.layer.append(UdopLayerFF(config))
@@ -928,34 +939,19 @@ class UdopBlock(nn.Module):
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
return_dict=True, return_dict=True,
cache_position=None,
): ):
if past_key_value is not None:
if not self.is_decoder:
logger.warning("`past_key_values` is passed to the encoder. Please make sure this is intended.")
expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
if len(past_key_value) != expected_num_past_key_values:
raise ValueError(
f"There should be {expected_num_past_key_values} past states. "
f"{'2 (key / value) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
f"Got {len(past_key_value)} past key / value states"
)
self_attn_past_key_value = past_key_value[:2]
cross_attn_past_key_value = past_key_value[2:]
else:
self_attn_past_key_value, cross_attn_past_key_value = None, None
self_attention_outputs = self.layer[0]( self_attention_outputs = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
position_bias=position_bias, position_bias=position_bias,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states, present_key_value_state = self_attention_outputs[:2] hidden_states, past_key_value = self_attention_outputs[:2]
attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -969,25 +965,18 @@ class UdopBlock(nn.Module):
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# the actual query length is unknown for cross attention
# if using past key value states. Need to inject it here
if present_key_value_state is not None:
query_length = present_key_value_state[0].shape[2]
else:
query_length = None
cross_attention_outputs = self.layer[1]( cross_attention_outputs = self.layer[1](
hidden_states, hidden_states,
key_value_states=encoder_hidden_states, key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
position_bias=encoder_decoder_position_bias, position_bias=encoder_decoder_position_bias,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
query_length=query_length, query_length=cache_position[-1] + 1,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
) )
hidden_states = cross_attention_outputs[0] hidden_states, past_key_value = cross_attention_outputs[:2]
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
@@ -998,10 +987,6 @@ class UdopBlock(nn.Module):
) )
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Combine self attn and cross attn key value states
if present_key_value_state is not None:
present_key_value_state = present_key_value_state + cross_attention_outputs[1]
# Keep cross-attention outputs and relative position weights # Keep cross-attention outputs and relative position weights
attention_outputs = attention_outputs + cross_attention_outputs[2:] attention_outputs = attention_outputs + cross_attention_outputs[2:]
@@ -1020,11 +1005,11 @@ class UdopBlock(nn.Module):
outputs = (hidden_states,) outputs = (hidden_states,)
if use_cache: if use_cache:
outputs = outputs + (present_key_value_state,) + attention_outputs outputs = outputs + (past_key_value,) + attention_outputs
else: else:
outputs = outputs + attention_outputs outputs = outputs + attention_outputs
return outputs # hidden-states, present_key_value_states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) return outputs # hidden-states, past_key_value, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
class UdopCellEmbeddings(nn.Module): class UdopCellEmbeddings(nn.Module):
@@ -1286,7 +1271,7 @@ class UdopStack(UdopPreTrainedModel):
self.num_layers = config.num_layers self.num_layers = config.num_layers
self.block = nn.ModuleList( self.block = nn.ModuleList(
[UdopBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(self.num_layers)] [UdopBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(self.num_layers)]
) )
self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = UdopLayerNorm(config.d_model, eps=config.layer_norm_epsilon)
@@ -1338,6 +1323,7 @@ class UdopStack(UdopPreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -1399,26 +1385,54 @@ class UdopStack(UdopPreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self) assert self.is_decoder, "`use_cache` can only be set to `True` if {} is used as a decoder".format(self)
if attention_mask is None: # initialize past_key_values
attention_mask = torch.ones(batch_size, mask_seq_length).to(inputs_embeds.device) return_legacy_cache = False
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: return_self_attention_cache = False
encoder_seq_length = encoder_hidden_states.shape[1] if self.is_decoder and (use_cache or past_key_values is not None):
encoder_attention_mask = torch.ones( if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
) )
# initialize past_key_values with `None` if past does not exist if attention_mask is None and not is_torchdynamo_compiling():
if past_key_values is None: # required mask seq length can be calculated via length of past cache
past_key_values = [None] * len(self.block) mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# ourselves in which case we just need to make it broadcastable to all heads. if self.config.is_decoder:
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
else:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
if self.is_decoder and encoder_attention_mask is not None: if self.is_decoder and encoder_attention_mask is not None:
encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
@@ -1427,7 +1441,6 @@ class UdopStack(UdopPreTrainedModel):
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.num_layers) head_mask = self.get_head_mask(head_mask, self.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if (output_attentions and self.is_decoder) else None all_cross_attentions = () if (output_attentions and self.is_decoder) else None
@@ -1436,34 +1449,35 @@ class UdopStack(UdopPreTrainedModel):
position_bias = None position_bias = None
else: else:
position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox) position_bias = self.relative_bias(attention_mask=attention_mask, bbox=bbox)
position_bias = position_bias + extended_attention_mask position_bias = position_bias + causal_mask
encoder_decoder_position_bias = None encoder_decoder_position_bias = None
hidden_states = inputs_embeds hidden_states = inputs_embeds
hidden_states = self.dropout(hidden_states) hidden_states = self.dropout(hidden_states)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
position_bias=position_bias, position_bias=position_bias,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
encoder_decoder_position_bias=encoder_decoder_position_bias, encoder_decoder_position_bias=encoder_decoder_position_bias,
layer_head_mask=head_mask[i], layer_head_mask=head_mask[i],
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
# layer_outputs is a tuple with: # layer_outputs is a tuple with:
# hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias) # hidden-states, key-value-states, (self-attention weights), (self-attention position bias), (cross-attention weights), (cross-attention position bias)
if use_cache is False: # MP fixes if use_cache is False: # MP fixes
layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:]
hidden_states, present_key_value_state = layer_outputs[:2] hidden_states, next_decoder_cache = layer_outputs[:2]
# We share the position biases between the layers - the first layer store them # We share the position biases between the layers - the first layer store them
# layer_outputs = hidden-states, key-value-states (self-attention weights), # layer_outputs = hidden-states, key-value-states (self-attention weights),
@@ -1472,9 +1486,6 @@ class UdopStack(UdopPreTrainedModel):
position_bias = layer_outputs[2] position_bias = layer_outputs[2]
if self.is_decoder and encoder_hidden_states is not None: if self.is_decoder and encoder_hidden_states is not None:
encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3]
# append next layer key value states
if use_cache:
present_key_value_states = present_key_value_states + (present_key_value_state,)
if output_attentions: if output_attentions:
all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now all_attentions = all_attentions + (layer_outputs[2],) # We keep only self-attention weights for now
@@ -1488,13 +1499,19 @@ class UdopStack(UdopPreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
attention_mask, attention_mask,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -1505,12 +1522,135 @@ class UdopStack(UdopPreTrainedModel):
return BaseModelOutputWithAttentionMask( return BaseModelOutputWithAttentionMask(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Select or build the attention mask that matches `self.config._attn_implementation`.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller; may be `None`. On the non-flash paths it is forwarded as-is to
            `_prepare_4d_causal_attention_mask_with_cache_position`.
        input_tensor (`torch.Tensor`):
            Tensor whose dtype, device, `shape[0]` (used as batch size) and `shape[1]` (used as sequence
            length) parameterize the generated mask.
        cache_position (`torch.Tensor`):
            Forwarded unchanged to `_prepare_4d_causal_attention_mask_with_cache_position`.
        past_key_values (`Cache`):
            Cache asked for the number of already seen tokens; may be `None`, in which case 0 is used.
        output_attentions (`bool`):
            When `True`, both SDPA-specific shortcuts in this method are skipped.

    Returns:
        - `"flash_attention_2"`: `attention_mask` itself if it is not `None` and contains `0.0`, else `None`.
        - `"sdpa"` without a static cache and without `output_attentions`: `None` when
          `AttentionMaskConverter._ignore_causal_mask_sdpa` reports that the mask can be dropped.
        - otherwise: the mask produced by `_prepare_4d_causal_attention_mask_with_cache_position`, additionally
          passed through `AttentionMaskConverter._unmask_unattended` for SDPA with a CUDA `attention_mask`.
    """
    # Flash attention path: no 4D mask is built here, the incoming mask is passed through or dropped.
    if self.config._attn_implementation == "flash_attention_2":
        if attention_mask is not None and 0.0 in attention_mask:
            return attention_mask
        return None
    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    using_static_cache = isinstance(past_key_values, StaticCache)
    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
    if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions:
        if AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=past_seen_tokens,
            is_training=self.training,
        ):
            return None
    dtype, device = input_tensor.dtype, input_tensor.device
    sequence_length = input_tensor.shape[1]
    # Key/value axis length of the mask: the static cache's maximum shape, else the width of the provided
    # mask, else a length derived from the tokens seen so far plus the current inputs.
    if using_static_cache:
        target_length = past_key_values.get_max_cache_shape()
    else:
        target_length = (
            attention_mask.shape[-1]
            if isinstance(attention_mask, torch.Tensor)
            else past_seen_tokens + sequence_length + 1
        )
    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
    causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=sequence_length,
        target_length=target_length,
        dtype=dtype,
        device=device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )
    if (
        self.config._attn_implementation == "sdpa"
        and attention_mask is not None
        and attention_mask.device.type == "cuda"
        and not output_attentions
    ):
        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
        # Details: https://github.com/pytorch/pytorch/issues/110213
        min_dtype = torch.finfo(dtype).min
        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
    return causal_mask
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

    Args:
        attention_mask (`torch.Tensor`):
            A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`. May be `None`, in which case only the causal
            part of the mask is generated.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static cache,
            to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.
        **kwargs:
            Accepted but not read by this implementation.

    Returns:
        `torch.Tensor`: the incoming `attention_mask` unchanged when it is 4D; otherwise a tensor expanded to
        `(batch_size, 1, sequence_length, target_length)` holding `0` at positions that may be attended to and
        `torch.finfo(dtype).min` at masked positions.
    """
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        # Start fully masked, then open up the allowed positions step by step.
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            # Keep the mask value strictly above the diagonal only; the diagonal and below become 0.
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Zero out (unmask) every key index that is not beyond the query's cache position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            # A sum of exactly 0 means "causally visible (0) and flagged 0 in `attention_mask`":
            # those positions are filled with the mask value as well.
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
    return causal_mask
@add_start_docstrings( @add_start_docstrings(
"The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.", "The bare UDOP encoder-decoder Transformer outputting raw hidden-states without any specific head on top.",
@@ -1584,6 +1724,7 @@ class UdopModel(UdopPreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[Tensor, ...]: ) -> Tuple[Tensor, ...]:
r""" r"""
Returns: Returns:
@@ -1653,6 +1794,7 @@ class UdopModel(UdopPreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1759,6 +1901,7 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
labels: Optional[Tensor] = None, labels: Optional[Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Tuple[Tensor, ...]: ) -> Tuple[Tensor, ...]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1837,6 +1980,7 @@ class UdopForConditionalGeneration(UdopPreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -72,7 +72,12 @@ class UMT5Config(PretrainedConfig):
model_type = "umt5" model_type = "umt5"
keys_to_ignore_at_inference = ["past_key_values"] keys_to_ignore_at_inference = ["past_key_values"]
attribute_map = {"hidden_size": "d_model", "num_attention_heads": "num_heads", "num_hidden_layers": "num_layers"} attribute_map = {
"hidden_size": "d_model",
"num_attention_heads": "num_heads",
"num_hidden_layers": "num_layers",
"head_dim": "d_kv",
}
def __init__( def __init__(
self, self,

View File

@@ -23,7 +23,9 @@ from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_outputs import ( from ...modeling_outputs import (
BaseModelOutput, BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions,
@@ -40,6 +42,7 @@ from ...utils import (
add_start_docstrings, add_start_docstrings,
add_start_docstrings_to_model_forward, add_start_docstrings_to_model_forward,
is_torch_fx_proxy, is_torch_fx_proxy,
is_torchdynamo_compiling,
logging, logging,
replace_return_docstrings, replace_return_docstrings,
) )
@@ -155,7 +158,7 @@ class UMT5Attention(nn.Module):
T5's attention using relative_attention_bias. T5's attention using relative_attention_bias.
""" """
def __init__(self, config, has_relative_attention_bias=False): def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.has_relative_attention_bias = has_relative_attention_bias self.has_relative_attention_bias = has_relative_attention_bias
@@ -166,6 +169,13 @@ class UMT5Attention(nn.Module):
self.n_heads = config.num_heads self.n_heads = config.num_heads
self.dropout = config.dropout_rate self.dropout = config.dropout_rate
self.inner_dim = self.n_heads * self.key_value_proj_dim self.inner_dim = self.n_heads * self.key_value_proj_dim
self.layer_idx = layer_idx
if layer_idx is None and self.is_decoder:
logger.warning_once(
f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
"will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)
# Mesh TensorFlow initialization to avoid scaling before softmax # Mesh TensorFlow initialization to avoid scaling before softmax
self.q = nn.Linear(self.d_model, self.inner_dim, bias=False) self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
@@ -230,11 +240,14 @@ class UMT5Attention(nn.Module):
relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
return relative_buckets return relative_buckets
def compute_bias(self, query_length, key_length, device=None): def compute_bias(self, query_length, key_length, device=None, cache_position=None):
"""Compute binned relative position bias""" """Compute binned relative position bias"""
if device is None: if device is None:
device = self.relative_attention_bias.weight.device device = self.relative_attention_bias.weight.device
if cache_position is None:
context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
else:
context_position = cache_position[:, None]
memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
relative_position = memory_position - context_position # shape (query_length, key_length) relative_position = memory_position - context_position # shape (query_length, key_length)
relative_position_bucket = self._relative_position_bucket(relative_position) relative_position_bucket = self._relative_position_bucket(relative_position)
@@ -249,78 +262,95 @@ class UMT5Attention(nn.Module):
past_key_value: Optional[Tuple[torch.Tensor]] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.Tensor] = None,
): ):
is_cross_attention = encoder_hidden_states is not None
batch_size, seq_length = hidden_states.shape[:2] batch_size, seq_length = hidden_states.shape[:2]
# use encoder_hidden_states if cross attention # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
current_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states is_cross_attention = encoder_hidden_states is not None
# checking that the `sequence_length` of the `past_key_value` is the same as the he provided
# `encoder_hidden_states` to support prefix tuning query_states = self.q(hidden_states)
if is_cross_attention and past_key_value and past_key_value[0].shape[2] == current_states.shape[1]: query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
if past_key_value is not None:
is_updated = past_key_value.is_updated.get(self.layer_idx)
if is_cross_attention:
# after the first generated id, we can subsequently re-use all key/value_states from cache
curr_past_key_value = past_key_value.cross_attention_cache
else:
curr_past_key_value = past_key_value.self_attention_cache
current_states = encoder_hidden_states if is_cross_attention else hidden_states
if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions # reuse k,v, cross_attentions
key_states = past_key_value[0] key_states = curr_past_key_value.key_cache[self.layer_idx]
value_states = past_key_value[1] value_states = curr_past_key_value.value_cache[self.layer_idx]
else: else:
key_states = self._shape(self.k(current_states)) key_states = self.k(current_states)
value_states = self._shape(self.v(current_states)) value_states = self.v(current_states)
if past_key_value is not None and not is_cross_attention: key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
# reuse k, v, self_attention value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
query_states = self._shape(self.q(hidden_states))
attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2))
# compute positional bias
if self.has_relative_attention_bias:
query_length = seq_length
if past_key_value is not None: if past_key_value is not None:
query_length += past_key_value[0].shape[2] # save all key/value_states to cache to be re-used for fast auto-regressive generation
position_bias = self.compute_bias(query_length, key_states.size(2), device=attention_scores.device) cache_position = cache_position if not is_cross_attention else None
else: key_states, value_states = curr_past_key_value.update(
position_bias = torch.zeros( key_states, value_states, self.layer_idx, {"cache_position": cache_position}
(1, self.n_heads, seq_length, key_states.size(2)),
device=attention_scores.device,
dtype=attention_scores.dtype,
requires_grad=self.training,
) )
if past_key_value is not None: # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
position_bias = position_bias[:, :, -hidden_states.size(1) :, :] if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
# compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
scores = torch.matmul(query_states, key_states.transpose(3, 2))
# cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
real_seq_length = seq_length + past_key_value.get_seq_length() if past_key_value is not None else seq_length
key_length = key_states.shape[-2]
if not self.has_relative_attention_bias:
position_bias = torch.zeros(
(1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
)
else:
position_bias = self.compute_bias(
real_seq_length, key_length, device=scores.device, cache_position=cache_position
)
position_bias = position_bias[:, :, -seq_length:, :]
if attention_mask is not None: if attention_mask is not None:
position_bias = position_bias + attention_mask # (batch_size, n_heads, seq_length, key_length) causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
position_bias = position_bias + causal_mask
if self.is_decoder: if self.pruned_heads:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. mask = torch.ones(position_bias.shape[1])
# Further calls to cross_attention layer can then reuse all cross-attention mask[list(self.pruned_heads)] = 0
# key/value_states (first "if" case) position_bias_masked = position_bias[:, mask.bool()]
# if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of else:
# all previous decoder key/value_states. Further calls to uni-directional self-attention position_bias_masked = position_bias
# can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
# if encoder bi-directional self-attention `past_key_value` is always `None` scores += position_bias_masked
past_key_value = (key_states, value_states)
attention_scores += position_bias
# (batch_size, n_heads, seq_length, key_length) # (batch_size, n_heads, seq_length, key_length)
attn_weights = nn.functional.softmax(attention_scores.float(), dim=-1).type_as(attention_scores) attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
# Mask heads if we want to # Mask heads if we want to
if layer_head_mask is not None: if layer_head_mask is not None:
attn_weights = attn_weights * layer_head_mask attn_weights = attn_weights * layer_head_mask
# attn_output = torch.bmm(attn_probs, value_states) ? attn_output = torch.matmul(attn_weights, value_states)
context_states = torch.matmul(attn_weights, value_states)
# attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) ? attn_output = attn_output.transpose(1, 2).contiguous()
context_states = context_states.permute(0, 2, 1, 3).contiguous().view(batch_size, seq_length, -1) attn_output = attn_output.view(batch_size, seq_length, -1)
attn_output = self.o(context_states)
attn_output = self.o(attn_output)
return attn_output, attn_weights, past_key_value return attn_output, attn_weights, past_key_value
class UMT5LayerSelfAttention(nn.Module): class UMT5LayerSelfAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True) self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx)
self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -330,6 +360,7 @@ class UMT5LayerSelfAttention(nn.Module):
attention_mask=None, attention_mask=None,
layer_head_mask=None, layer_head_mask=None,
past_key_value=None, past_key_value=None,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.SelfAttention( attention_output = self.SelfAttention(
@@ -337,6 +368,7 @@ class UMT5LayerSelfAttention(nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_value,
cache_position=cache_position,
) )
hidden_states = hidden_states + self.dropout(attention_output[0]) hidden_states = hidden_states + self.dropout(attention_output[0])
outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
@@ -344,9 +376,9 @@ class UMT5LayerSelfAttention(nn.Module):
class UMT5LayerCrossAttention(nn.Module): class UMT5LayerCrossAttention(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False) self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -357,6 +389,7 @@ class UMT5LayerCrossAttention(nn.Module):
attention_mask=None, attention_mask=None,
layer_head_mask=None, layer_head_mask=None,
past_key_value=None, past_key_value=None,
cache_position=None,
): ):
normed_hidden_states = self.layer_norm(hidden_states) normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.EncDecAttention( attention_output = self.EncDecAttention(
@@ -365,6 +398,7 @@ class UMT5LayerCrossAttention(nn.Module):
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_value,
cache_position=cache_position,
) )
layer_output = hidden_states + self.dropout(attention_output[0]) layer_output = hidden_states + self.dropout(attention_output[0])
outputs = (layer_output,) + attention_output[1:] # add attentions if we output them outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
@@ -372,13 +406,13 @@ class UMT5LayerCrossAttention(nn.Module):
class UMT5Block(nn.Module): class UMT5Block(nn.Module):
def __init__(self, config): def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__() super().__init__()
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.layer = nn.ModuleList() self.layer = nn.ModuleList()
self.layer.append(UMT5LayerSelfAttention(config)) self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx))
if self.is_decoder: if self.is_decoder:
self.layer.append(UMT5LayerCrossAttention(config)) self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx))
self.layer.append(UMT5LayerFF(config)) self.layer.append(UMT5LayerFF(config))
@@ -393,16 +427,14 @@ class UMT5Block(nn.Module):
past_key_value=None, past_key_value=None,
use_cache=False, use_cache=False,
output_attentions=False, output_attentions=False,
cache_position=None,
): ):
# Self Attention hidden_states, self_attn_weights, past_key_value = self.layer[0](
# decoder uni-directional self-attention cached key/values tuple is at positions 1,2
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
hidden_states, self_attn_weights, present_key_value = self.layer[0](
hidden_states, hidden_states,
attention_mask=attention_mask, attention_mask=attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
past_key_value=self_attn_past_key_value, past_key_value=past_key_value,
cache_position=cache_position,
) )
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
@@ -412,18 +444,16 @@ class UMT5Block(nn.Module):
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
# Cross-Attention Block # Cross-Attention Block
cross_attn_present_key_value = None
cross_attn_weights = None cross_attn_weights = None
do_cross_attention = self.is_decoder and encoder_hidden_states is not None do_cross_attention = self.is_decoder and encoder_hidden_states is not None
if do_cross_attention: if do_cross_attention:
# cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple hidden_states, cross_attn_weights, past_key_value = self.layer[1](
cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
hidden_states, cross_attn_weights, cross_attn_present_key_value = self.layer[1](
hidden_states, hidden_states,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask, attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask, layer_head_mask=cross_attn_layer_head_mask,
past_key_value=cross_attn_past_key_value, past_key_value=past_key_value,
cache_position=cache_position,
) )
# clamp inf values to enable fp16 training # clamp inf values to enable fp16 training
if hidden_states.dtype == torch.float16: if hidden_states.dtype == torch.float16:
@@ -431,8 +461,6 @@ class UMT5Block(nn.Module):
clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype) clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
present_key_value += cross_attn_present_key_value
# Apply Feed Forward layer # Apply Feed Forward layer
hidden_states = self.layer[-1](hidden_states) hidden_states = self.layer[-1](hidden_states)
@@ -444,7 +472,7 @@ class UMT5Block(nn.Module):
outputs = ( outputs = (
hidden_states, hidden_states,
present_key_value, past_key_value,
) )
if output_attentions: if output_attentions:
@@ -481,6 +509,8 @@ class UMT5PreTrainedModel(PreTrainedModel):
config_class = UMT5Config config_class = UMT5Config
base_model_prefix = "transformer" base_model_prefix = "transformer"
supports_gradient_checkpointing = True supports_gradient_checkpointing = True
_supports_cache_class = True
_supports_static_cache = True
_no_split_modules = ["UMT5Block"] _no_split_modules = ["UMT5Block"]
_keep_in_fp32_modules = ["wo"] _keep_in_fp32_modules = ["wo"]
@@ -594,7 +624,7 @@ class UMT5Stack(UMT5PreTrainedModel):
super().__init__(config) super().__init__(config)
self.embed_tokens = embed_tokens self.embed_tokens = embed_tokens
self.is_decoder = config.is_decoder self.is_decoder = config.is_decoder
self.block = nn.ModuleList([UMT5Block(config) for i in range(config.num_layers)]) self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)])
self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon) self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate) self.dropout = nn.Dropout(config.dropout_rate)
@@ -622,6 +652,7 @@ class UMT5Stack(UMT5PreTrainedModel):
output_attentions=None, output_attentions=None,
output_hidden_states=None, output_hidden_states=None,
return_dict=None, return_dict=None,
cache_position=None,
): ):
use_cache = use_cache if use_cache is not None else self.config.use_cache use_cache = use_cache if use_cache is not None else self.config.use_cache
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
@@ -644,6 +675,13 @@ class UMT5Stack(UMT5PreTrainedModel):
err_msg_prefix = "decoder_" if self.is_decoder else "" err_msg_prefix = "decoder_" if self.is_decoder else ""
raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
if inputs_embeds is None: if inputs_embeds is None:
if self.embed_tokens is None: if self.embed_tokens is None:
raise ValueError("You have to initialize the model with valid token embeddings") raise ValueError("You have to initialize the model with valid token embeddings")
@@ -651,28 +689,57 @@ class UMT5Stack(UMT5PreTrainedModel):
batch_size, seq_length = input_shape batch_size, seq_length = input_shape
# required mask seq length can be calculated via length of past
mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length
if use_cache is True: if use_cache is True:
if not self.is_decoder: if not self.is_decoder:
raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
if attention_mask is None: # initialize past_key_values
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) return_legacy_cache = False
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: return_self_attention_cache = False
encoder_seq_length = encoder_hidden_states.shape[1] if self.is_decoder and (use_cache or past_key_values is not None):
encoder_attention_mask = torch.ones( if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache):
batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long return_self_attention_cache = True
past_key_values = EncoderDecoderCache(past_key_values, DynamicCache())
elif not isinstance(past_key_values, EncoderDecoderCache):
return_legacy_cache = True
logger.warning_once(
"Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. "
"You should pass an instance of `EncoderDecoderCache` instead, e.g. "
"`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
)
past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
elif past_key_values is None:
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
elif not self.is_decoder:
# do not pass cache object down the line for encoder stack
# it messes indexing later in decoder-stack because cache object is modified in-place
past_key_values = None
past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
if cache_position is None:
cache_position = torch.arange(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
) )
# initialize past_key_values with `None` if past does not exist if attention_mask is None and not is_torchdynamo_compiling():
if past_key_values is None: # required mask seq length can be calculated via length of past cache
past_key_values = [None] * len(self.block) mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
# We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] if self.is_decoder:
# ourselves in which case we just need to make it broadcastable to all heads. causal_mask = self._update_causal_mask(
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) attention_mask,
inputs_embeds,
cache_position,
past_key_values.self_attention_cache if past_key_values is not None else None,
output_attentions,
)
elif attention_mask is not None:
causal_mask = attention_mask[:, None, None, :]
causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
else:
causal_mask = None
# If a 2D or 3D attention mask is provided for the cross-attention # If a 2D or 3D attention mask is provided for the cross-attention
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
@@ -685,24 +752,16 @@ class UMT5Stack(UMT5PreTrainedModel):
else: else:
encoder_extended_attention_mask = None encoder_extended_attention_mask = None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
# Prepare head mask if needed # Prepare head mask if needed
head_mask = self.get_head_mask(head_mask, self.config.num_layers) head_mask = self.get_head_mask(head_mask, self.config.num_layers)
cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
present_key_value_states = () if use_cache else None
all_hidden_states = () if output_hidden_states else None all_hidden_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None all_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.is_decoder else None all_cross_attentions = () if output_attentions and self.is_decoder else None
hidden_states = self.dropout(inputs_embeds) hidden_states = self.dropout(inputs_embeds)
for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): for i, layer_module in enumerate(self.block):
layer_head_mask = head_mask[i] layer_head_mask = head_mask[i]
cross_attn_layer_head_mask = cross_attn_head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i]
@@ -713,7 +772,7 @@ class UMT5Stack(UMT5PreTrainedModel):
layer_outputs = self._gradient_checkpointing_func( layer_outputs = self._gradient_checkpointing_func(
layer_module.forward, layer_module.forward,
hidden_states, hidden_states,
extended_attention_mask, causal_mask,
encoder_hidden_states, encoder_hidden_states,
encoder_extended_attention_mask, encoder_extended_attention_mask,
layer_head_mask, layer_head_mask,
@@ -721,24 +780,26 @@ class UMT5Stack(UMT5PreTrainedModel):
None, # past_key_value is always None with gradient checkpointing None, # past_key_value is always None with gradient checkpointing
use_cache, use_cache,
output_attentions, output_attentions,
cache_position,
) )
else: else:
layer_outputs = layer_module( layer_outputs = layer_module(
hidden_states, hidden_states,
attention_mask=extended_attention_mask, attention_mask=causal_mask,
encoder_hidden_states=encoder_hidden_states, encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask, encoder_attention_mask=encoder_extended_attention_mask,
layer_head_mask=layer_head_mask, layer_head_mask=layer_head_mask,
cross_attn_layer_head_mask=cross_attn_layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value, past_key_value=past_key_values,
use_cache=use_cache, use_cache=use_cache,
output_attentions=output_attentions, output_attentions=output_attentions,
cache_position=cache_position,
) )
hidden_states = layer_outputs[0] hidden_states = layer_outputs[0]
if use_cache: if use_cache:
present_key_value_states += (layer_outputs[1],) next_decoder_cache = layer_outputs[1]
if output_attentions: if output_attentions:
all_attentions += (layer_outputs[2],) all_attentions += (layer_outputs[2],)
@@ -752,12 +813,18 @@ class UMT5Stack(UMT5PreTrainedModel):
if output_hidden_states: if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,) all_hidden_states = all_hidden_states + (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
if return_self_attention_cache:
next_cache = past_key_values.self_attention_cache
if return_legacy_cache:
next_cache = past_key_values.to_legacy_cache()
if not return_dict: if not return_dict:
return tuple( return tuple(
v v
for v in [ for v in [
hidden_states, hidden_states,
present_key_value_states, next_cache,
all_hidden_states, all_hidden_states,
all_attentions, all_attentions,
all_cross_attentions, all_cross_attentions,
@@ -766,12 +833,135 @@ class UMT5Stack(UMT5PreTrainedModel):
) )
return BaseModelOutputWithPastAndCrossAttentions( return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states, last_hidden_state=hidden_states,
past_key_values=present_key_value_states, past_key_values=next_cache,
hidden_states=all_hidden_states, hidden_states=all_hidden_states,
attentions=all_attentions, attentions=all_attentions,
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
# Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask
def _update_causal_mask(
    self,
    attention_mask: torch.Tensor,
    input_tensor: torch.Tensor,
    cache_position: torch.Tensor,
    past_key_values: Cache,
    output_attentions: bool,
):
    """
    Pick the attention mask to hand to the layers for the configured attention implementation.

    Args:
        attention_mask (`torch.Tensor`):
            Mask supplied by the caller, or `None`. It is forwarded unchanged to
            `_prepare_4d_causal_attention_mask_with_cache_position` when a 4D mask has to be built.
        input_tensor (`torch.Tensor`):
            Tensor whose `shape[0]` / `shape[1]` are used as batch size / processed sequence length and whose
            dtype and device are used for the generated mask.
        cache_position (`torch.Tensor`):
            Positions of the processed tokens, forwarded to the 4D mask builder.
        past_key_values (`Cache`):
            Cache object or `None`. Its `get_seq_length()` gives the number of tokens already seen; for a
            `StaticCache`, `get_max_cache_shape()` gives the key length of the mask.
        output_attentions (`bool`):
            When truthy, the SDPA-specific shortcuts below are skipped.

    Returns:
        `None` when no explicit mask is needed, the incoming `attention_mask` itself (flash attention with at
        least one zero entry), or the 4D mask produced by the builder.
    """
    attn_impl = self.config._attn_implementation

    # Flash attention only gets a mask when the caller's mask actually contains a zero entry.
    if attn_impl == "flash_attention_2":
        has_masked_entry = attention_mask is not None and 0.0 in attention_mask
        return attention_mask if has_masked_entry else None

    seen_tokens = 0 if past_key_values is None else past_key_values.get_seq_length()
    static_cache_in_use = isinstance(past_key_values, StaticCache)
    # With `output_attentions`, the sdpa implementation's forward calls the eager implementation's forward,
    # so none of the SDPA-specific handling applies.
    sdpa_without_attentions = attn_impl == "sdpa" and not output_attentions

    # For SDPA, when possible, rely on its `is_causal` argument instead of `attn_mask` so that it can dispatch
    # on Flash Attention 2. This is not compatible with a static cache, as SDPA would fail to infer the mask.
    if (
        sdpa_without_attentions
        and not static_cache_in_use
        and AttentionMaskConverter._ignore_causal_mask_sdpa(
            attention_mask,
            inputs_embeds=input_tensor,
            past_key_values_length=seen_tokens,
            is_training=self.training,
        )
    ):
        return None

    query_length = input_tensor.shape[1]
    if static_cache_in_use:
        key_length = past_key_values.get_max_cache_shape()
    elif isinstance(attention_mask, torch.Tensor):
        key_length = attention_mask.shape[-1]
    else:
        key_length = seen_tokens + query_length + 1

    # If the provided mask is 2D (or absent), a 4D causal mask is generated here.
    mask_4d = self._prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask,
        sequence_length=query_length,
        target_length=key_length,
        dtype=input_tensor.dtype,
        device=input_tensor.device,
        cache_position=cache_position,
        batch_size=input_tensor.shape[0],
    )

    if sdpa_without_attentions and attention_mask is not None and attention_mask.device.type == "cuda":
        # Attend to all tokens in fully masked rows of the mask, for example the relevant first rows when
        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention
        # path. Details: https://github.com/pytorch/pytorch/issues/110213
        mask_4d = AttentionMaskConverter._unmask_unattended(mask_4d, torch.finfo(input_tensor.dtype).min)

    return mask_4d
@staticmethod
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """
    Build a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
    `(batch_size, key_value_length)`. A mask that is already 4D is returned untouched.

    Args:
        attention_mask (`torch.Tensor`):
            Either a 2D attention mask of shape `(batch_size, key_value_length)`, a 4D attention mask of shape
            `(batch_size, 1, query_length, key_value_length)`, or `None`.
        sequence_length (`int`):
            The sequence length being processed.
        target_length (`int`):
            The target length: when generating with static cache, the mask should be as long as the static
            cache, to account for the 0 padding, the part of the cache that is not filled yet.
        dtype (`torch.dtype`):
            The dtype to use for the 4D attention mask.
        device (`torch.device`):
            The device to place the 4D attention mask on.
        cache_position (`torch.Tensor`):
            Indices depicting the position of the input sequence tokens in the sequence.
        batch_size (`int`):
            Batch size.

    Returns:
        `torch.Tensor`: the 4D mask, holding `torch.finfo(dtype).min` at blocked positions and `0` elsewhere
        (or the caller's own 4D mask, unchanged).
    """
    # A 4D mask is assumed to be already inverted; it needs neither inversion nor slicing.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    blocked_value = torch.finfo(dtype).min
    mask_2d = torch.full((sequence_length, target_length), fill_value=blocked_value, dtype=dtype, device=device)
    if sequence_length != 1:
        mask_2d = torch.triu(mask_2d, diagonal=1)
    # Keep the blocking value only for key slots located after each query's cache position.
    key_slots = torch.arange(target_length, device=device)
    mask_2d *= key_slots > cache_position.reshape(-1, 1)
    mask_4d = mask_2d[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask_4d

    mask_4d = mask_4d.clone()  # copy to contiguous memory for in-place edit
    covered = attention_mask.shape[-1]
    covered_view = mask_4d[:, :, :, :covered]
    # A slot is padding when it is causally visible (0) and the 2D mask is 0 there as well.
    is_padding = (covered_view + attention_mask[:, None, None, :]) == 0
    mask_4d[:, :, :, :covered] = covered_view.masked_fill(is_padding, blocked_value)
    return mask_4d
UMT5_START_DOCSTRING = r""" UMT5_START_DOCSTRING = r"""
@@ -885,6 +1075,9 @@ UMT5_INPUTS_DOCSTRING = r"""
more detail. more detail.
return_dict (`bool`, *optional*): return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
cache in the correct position and to infer the complete sequence length.
""" """
UMT5_ENCODER_INPUTS_DOCSTRING = r""" UMT5_ENCODER_INPUTS_DOCSTRING = r"""
@@ -1022,6 +1215,7 @@ class UMT5Model(UMT5PreTrainedModel):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
r""" r"""
Returns: Returns:
@@ -1084,6 +1278,7 @@ class UMT5Model(UMT5PreTrainedModel):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
if not return_dict: if not return_dict:
@@ -1197,6 +1392,7 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
output_attentions: Optional[bool] = None, output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None, output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None, return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
r""" r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1268,6 +1464,7 @@ class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
output_attentions=output_attentions, output_attentions=output_attentions,
output_hidden_states=output_hidden_states, output_hidden_states=output_hidden_states,
return_dict=return_dict, return_dict=return_dict,
cache_position=cache_position,
) )
sequence_output = decoder_outputs[0] sequence_output = decoder_outputs[0]

View File

@@ -31,6 +31,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
MODEL_FOR_QUESTION_ANSWERING_MAPPING, MODEL_FOR_QUESTION_ANSWERING_MAPPING,
@@ -574,6 +575,41 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
lm_labels, lm_labels,
) )
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_decoder_model_past_with_large_inputs(self): def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)
@@ -602,7 +638,7 @@ class LongT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMix
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/longt5_test.onnx", f"{tmpdirname}/longt5_test.onnx",
export_params=True, export_params=True,
opset_version=13, opset_version=14,
input_names=["input_ids", "decoder_input_ids"], input_names=["input_ids", "decoder_input_ids"],
) )

View File

@@ -40,6 +40,7 @@ if is_torch_fx_available():
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
@@ -575,6 +576,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small MT5 model needs higher percentages for CPU/MP tests # The small MT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9] model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/mt5-small"
def setUp(self): def setUp(self):
self.model_tester = MT5ModelTester(self) self.model_tester = MT5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37) self.config_tester = ConfigTester(self, config_class=MT5Config, d_model=37)
@@ -627,12 +631,9 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
] ]
if labels is not None: if labels is not None:
input_names.append("labels") input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys()) input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names) traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs) traced_output = traced_model(**filtered_inputs)
else: else:
@@ -647,7 +648,6 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"visual_feats", "visual_feats",
"visual_pos", "visual_pos",
] ]
labels = inputs.get("labels", None) labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None) start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None) end_positions = inputs.get("end_positions", None)
@@ -657,15 +657,12 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_names.append("start_positions") input_names.append("start_positions")
if end_positions is not None: if end_positions is not None:
input_names.append("end_positions") input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys()) input_names = list(filtered_inputs.keys())
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None not hasattr(model.config, "problem_type") or model.config.problem_type is None
): ):
model.config.problem_type = "single_label_classification" model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names) traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs) traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
@@ -718,6 +715,41 @@ class MT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# (Even with this call, there are still memory leak by ~0.04MB) # (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry() self.clear_torch_jit_class_registry()
# overwrite because MT5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()

View File

@@ -620,7 +620,7 @@ class Pop2PianoModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTester
(config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]), (config_and_inputs[1], config_and_inputs[3], config_and_inputs[2]),
f"{tmpdirname}/Pop2Piano_test.onnx", f"{tmpdirname}/Pop2Piano_test.onnx",
export_params=True, export_params=True,
opset_version=9, opset_version=14,
input_names=["input_ids", "decoder_input_ids"], input_names=["input_ids", "decoder_input_ids"],
) )

View File

@@ -36,6 +36,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
@@ -645,6 +646,41 @@ class SwitchTransformersModelTest(ModelTesterMixin, GenerationTesterMixin, Pipel
lm_labels, lm_labels,
) )
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_decoder_model_past_with_large_inputs(self): def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

View File

@@ -27,6 +27,7 @@ from transformers.testing_utils import (
require_sentencepiece, require_sentencepiece,
require_tokenizers, require_tokenizers,
require_torch, require_torch,
require_torch_gpu,
slow, slow,
torch_device, torch_device,
) )
@@ -44,6 +45,7 @@ if is_torch_fx_available():
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
@@ -578,6 +580,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# The small T5 model needs higher percentages for CPU/MP tests # The small T5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9] model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google-t5/t5-small"
def setUp(self): def setUp(self):
self.model_tester = T5ModelTester(self) self.model_tester = T5ModelTester(self)
self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37) self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)
@@ -630,12 +635,9 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
] ]
if labels is not None: if labels is not None:
input_names.append("labels") input_names.append("labels")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys()) input_names = list(filtered_inputs.keys())
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
traced_model = symbolic_trace(model, input_names) traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs) traced_output = traced_model(**filtered_inputs)
else: else:
@@ -650,7 +652,6 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
"visual_feats", "visual_feats",
"visual_pos", "visual_pos",
] ]
labels = inputs.get("labels", None) labels = inputs.get("labels", None)
start_positions = inputs.get("start_positions", None) start_positions = inputs.get("start_positions", None)
end_positions = inputs.get("end_positions", None) end_positions = inputs.get("end_positions", None)
@@ -660,15 +661,12 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
input_names.append("start_positions") input_names.append("start_positions")
if end_positions is not None: if end_positions is not None:
input_names.append("end_positions") input_names.append("end_positions")
filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names} filtered_inputs = {k: v for (k, v) in inputs.items() if k in input_names}
input_names = list(filtered_inputs.keys()) input_names = list(filtered_inputs.keys())
if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and ( if model.__class__.__name__ in set(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES.values()) and (
not hasattr(model.config, "problem_type") or model.config.problem_type is None not hasattr(model.config, "problem_type") or model.config.problem_type is None
): ):
model.config.problem_type = "single_label_classification" model.config.problem_type = "single_label_classification"
traced_model = symbolic_trace(model, input_names) traced_model = symbolic_trace(model, input_names)
traced_output = traced_model(**filtered_inputs) traced_output = traced_model(**filtered_inputs)
model_output = model(**filtered_inputs) model_output = model(**filtered_inputs)
@@ -721,6 +719,41 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin,
# (Even with this call, there are still memory leak by ~0.04MB) # (Even with this call, there are still memory leak by ~0.04MB)
self.clear_torch_jit_class_registry() self.clear_torch_jit_class_registry()
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_config(self): def test_config(self):
self.config_tester.run_common_tests() self.config_tester.run_common_tests()
@@ -1482,6 +1515,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
[model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]], [model.config.prefix + x for x in [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY]],
padding="max_length", padding="max_length",
truncation=True, truncation=True,
max_length=512,
return_tensors="pt", return_tensors="pt",
).to(torch_device) ).to(torch_device)
self.assertEqual(512, dct["input_ids"].shape[1]) self.assertEqual(512, dct["input_ids"].shape[1])
@@ -1604,6 +1638,8 @@ class T5ModelIntegrationTests(unittest.TestCase):
outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64) outputs = t5_model.generate(input_ids, penalty_alpha=0.5, top_k=5, max_length=64)
generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True) generated_text = t5_tokenizer.batch_decode(outputs, skip_special_tokens=True)
# TODO: @arthur?
# PR #31938 caused a regression in this test, which was fixed by PR #34089
self.assertListEqual( self.assertListEqual(
generated_text, generated_text,
[ [
@@ -1612,6 +1648,66 @@ class T5ModelIntegrationTests(unittest.TestCase):
], ],
) )
@slow
@require_torch_gpu
def test_compile_static_cache(self):
    """Check that greedy generation yields the same text for three cache setups.

    The `google-t5/t5-small` checkpoint summarizes two prompts with (1) the default
    cache, (2) `cache_implementation="static"`, and (3) the static cache after
    `model.forward` has been wrapped with `torch.compile`. Each decoded batch must
    equal the expected completions.
    """
    max_new_tokens = 40
    expected_completions = [
        "theory of relativity states that 1) the speed of light is constant in all inertial reference frames. the laws of physics are the same for all inertial reference frames.",
        "ketchup is my favorite condiment.",
    ]
    source_texts = [
        "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
        "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
        "theory of relativity is not hard to grasp.",
        "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
        "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.",
    ]

    model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small").to(torch_device)
    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
    inputs = tokenizer(source_texts, return_tensors="pt", padding=True).to(model.device)

    def greedy_decode(**cache_kwargs):
        # One greedy generation pass followed by decoding; `cache_kwargs` selects the cache type.
        output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False, **cache_kwargs)
        return tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    # Dynamic Cache
    self.assertEqual(expected_completions, greedy_decode())

    # Static Cache
    self.assertEqual(expected_completions, greedy_decode(cache_implementation="static"))

    # Static Cache + compile (the wrapped forward is only installed after the two eager runs)
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    self.assertEqual(expected_completions, greedy_decode(cache_implementation="static"))
@slow
@require_torch_gpu
def test_compile_static_cache_encoder(self):
    """Check that the encoder-only T5 model gives matching outputs eagerly and under `torch.compile`.

    The same tokenized batch is passed through `T5EncoderModel` before and after
    `model.forward` is wrapped with `torch.compile`; a slice of the first output
    element is compared with `atol=1e-5`.
    """
    source_texts = [
        "summarize: Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial "
        "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe "
        "theory of relativity is not hard to grasp.",
        "summarize: My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, "
        "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my pizza.",
    ]

    encoder = T5EncoderModel.from_pretrained("google-t5/t5-small").to(torch_device)
    tokenizer = T5Tokenizer.from_pretrained("google-t5/t5-small")
    batch = tokenizer(source_texts, return_tensors="pt", padding=True).to(encoder.device)

    # Eager reference pass, taken before the forward method is replaced.
    eager_outputs = encoder(**batch)

    encoder.forward = torch.compile(encoder.forward, mode="reduce-overhead", fullgraph=True)
    compiled_outputs = encoder(**batch)

    # First output element: last three entries along dim 1, index -3 along the last dim.
    eager_slice = eager_outputs[0][:, -3:, -3]
    compiled_slice = compiled_outputs[0][:, -3:, -3]
    self.assertTrue(torch.allclose(eager_slice, compiled_slice, atol=1e-5))
@require_torch @require_torch
class TestAsymmetricT5(unittest.TestCase): class TestAsymmetricT5(unittest.TestCase):

View File

@@ -37,6 +37,7 @@ from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor from transformers import UdopEncoderModel, UdopForConditionalGeneration, UdopModel, UdopProcessor
@@ -348,6 +349,7 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
expected_arg_names = [ expected_arg_names = [
"attention_mask", "attention_mask",
"bbox", "bbox",
"cache_position",
"cross_attn_head_mask", "cross_attn_head_mask",
"decoder_attention_mask", "decoder_attention_mask",
"decoder_head_mask", "decoder_head_mask",
@@ -365,6 +367,43 @@ class UdopModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
expected_arg_names = sorted(expected_arg_names) expected_arg_names = sorted(expected_arg_names)
self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names) self.assertListEqual(sorted(arg_names[: len(expected_arg_names)]), expected_arg_names)
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows and their `bbox`) and once as a single shared-prefix row together
    with `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits
    of the last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
            bbox=input_dict["bbox"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            bbox=input_dict["bbox"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@unittest.skip( @unittest.skip(
"Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!"
) )
@@ -534,6 +573,41 @@ class UdopEncoderOnlyModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs) self.model_tester.create_and_check_model(*config_and_inputs)
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every class in `self.all_generative_model_classes`, the decoder inputs returned
    by `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first
    three encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    # NOTE(review): this override sits in the encoder-only test class yet passes
    # `decoder_input_ids`; presumably `all_generative_model_classes` is empty here
    # so the loop body never runs -- confirm.
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
@unittest.skip( @unittest.skip(
"Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!" "Not currently compatible. Fails with - NotImplementedError: Cannot copy out of meta tensor; no data!"
) )

View File

@@ -41,6 +41,7 @@ if is_torch_fx_available():
if is_torch_available(): if is_torch_available():
import torch import torch
import torch.nn.functional as F
from transformers import ( from transformers import (
AutoTokenizer, AutoTokenizer,
@@ -316,6 +317,9 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
# The small UMT5 model needs higher percentages for CPU/MP tests # The small UMT5 model needs higher percentages for CPU/MP tests
model_split_percents = [0.5, 0.8, 0.9] model_split_percents = [0.5, 0.8, 0.9]
# used in `test_torch_compile`
_torch_compile_test_ckpt = "google/umt5-small"
def setUp(self): def setUp(self):
self.model_tester = UMT5ModelTester(self) self.model_tester = UMT5ModelTester(self)
@@ -486,6 +490,41 @@ class UMT5ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
with torch.no_grad(): with torch.no_grad():
model(**inputs)[0] model(**inputs)[0]
# overwrite because T5 doesn't accept position ids as input and expects `decoder_input_ids`
def test_custom_4d_attention_mask(self):
    """Compare batched decoding against shared-prefix decoding with a custom decoder mask.

    For every generative model class, the decoder inputs returned by
    `_get_custom_4d_mask_test_data` are run twice: once as a plain batch (first three
    encoder rows) and once as a single shared-prefix row together with
    `decoder_attention_mask=mask_shared_prefix`. The softmax-normalized logits of the
    last token of each batched row must match the last three positions of the
    shared-prefix row.
    """
    for model_class in self.all_generative_model_classes:
        config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
        # NOTE(review): the model is used exactly as constructed (no `.eval()` call);
        # confirm whether dropout should be disabled for this comparison.
        model = model_class(config).to(device=torch_device, dtype=torch.float32)

        (
            input_ids,
            _,
            input_ids_shared_prefix,
            mask_shared_prefix,
            _,
        ) = self._get_custom_4d_mask_test_data()

        logits = model.forward(
            decoder_input_ids=input_ids,
            input_ids=input_dict["input_ids"][:3],
        ).logits
        # logits.shape == torch.Size([3, 4, ...])

        logits_shared_prefix = model(
            input_ids=input_dict["input_ids"][:1],
            decoder_input_ids=input_ids_shared_prefix,
            decoder_attention_mask=mask_shared_prefix,
        )[0]
        # logits_shared_prefix.shape == torch.Size([1, 6, ...])

        out_last_tokens = logits[:, -1, :]  # last tokens in each batch line
        out_shared_prefix_last_tokens = logits_shared_prefix[0, -3:, :]  # last three tokens

        # comparing softmax-normalized logits:
        # `dim` is passed explicitly because the implicit-dimension form of softmax is
        # deprecated; both tensors are 2D here, so the last axis is the one the
        # implicit form selected and the result is unchanged.
        normalized_0 = F.softmax(out_last_tokens, dim=-1)
        normalized_1 = F.softmax(out_shared_prefix_last_tokens, dim=-1)
        torch.testing.assert_close(normalized_0, normalized_1, rtol=1e-3, atol=1e-4)
def test_with_sequence_classification_head(self): def test_with_sequence_classification_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs() config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs) self.model_tester.create_and_check_with_sequence_classification_head(*config_and_inputs)

View File

@@ -37,6 +37,7 @@ import transformers
from transformers import ( from transformers import (
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
AutoTokenizer, AutoTokenizer,
GenerationConfig, GenerationConfig,
@@ -5109,7 +5110,12 @@ class ModelTesterMixin:
batch_size = 1 batch_size = 1
n_iter = 3 n_iter = 3
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device torch_device
) )
@@ -5184,7 +5190,12 @@ class ModelTesterMixin:
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
tokenizer = AutoTokenizer.from_pretrained(ckpt, revision=revision) tokenizer = AutoTokenizer.from_pretrained(ckpt)
if self.is_encoder_decoder:
model = AutoModelForSeq2SeqLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device
)
else:
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to( model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.float16, revision=revision).to(
torch_device torch_device
) )