[EncoderDecoder] Add Cross Attention for GPT2 (#6415)
* add cross attention layers for gpt2 * make gpt2 cross attention work * finish bert2gpt2 * add explicit comments * remove attention mask since not yet supported * revert attn mask in pipeline * Update src/transformers/modeling_gpt2.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_encoder_decoder.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
parent
eb613b566a
commit
1d6e71e116
@@ -372,11 +372,16 @@ class GenerationMixin:
|
||||
|
||||
if self.config.is_encoder_decoder:
|
||||
if decoder_start_token_id is None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
# see if BOS token can be used for decoder_start_token_id
|
||||
if bos_token_id is not None:
|
||||
decoder_start_token_id = bos_token_id
|
||||
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
|
||||
decoder_start_token_id = self.config.decoder.bos_token_id
|
||||
else:
|
||||
raise ValueError(
|
||||
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||
)
|
||||
|
||||
assert (
|
||||
decoder_start_token_id is not None
|
||||
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
|
||||
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
|
||||
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
|
||||
|
||||
|
||||
@@ -287,6 +287,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
**kwargs_decoder,
|
||||
)
|
||||
|
||||
# TODO(PVP): currently it is not possible to use `past`
|
||||
# with the encoder/decoder framework -> should be implemented
|
||||
return decoder_outputs + encoder_outputs
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
|
||||
@@ -299,15 +301,24 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
encoder_outputs = (past,)
|
||||
|
||||
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
|
||||
|
||||
return {
|
||||
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
|
||||
input_dict = {
|
||||
"attention_mask": attention_mask,
|
||||
"decoder_attention_mask": decoder_inputs["attention_mask"],
|
||||
"decoder_attention_mask": decoder_attention_mask,
|
||||
"decoder_input_ids": decoder_inputs["input_ids"],
|
||||
"encoder_outputs": encoder_outputs,
|
||||
}
|
||||
|
||||
# Ideally all models should have a `use_cache`
|
||||
# leave following to ifs until all have it implemented
|
||||
if "use_cache" in decoder_inputs:
|
||||
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
|
||||
|
||||
if "past_key_values" in decoder_inputs:
|
||||
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
|
||||
|
||||
return input_dict
|
||||
|
||||
def _reorder_cache(self, past, beam_idx):
|
||||
# as a default encoder-decoder models do not re-order the past.
|
||||
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
|
||||
return past
|
||||
# apply decoder cache reordering here
|
||||
return self.decoder._reorder_cache(past, beam_idx)
|
||||
|
||||
@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, nx, n_ctx, config, scale=False):
|
||||
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
|
||||
super().__init__()
|
||||
|
||||
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
||||
@@ -131,8 +131,12 @@ class Attention(nn.Module):
|
||||
self.n_head = config.n_head
|
||||
self.split_size = n_state
|
||||
self.scale = scale
|
||||
|
||||
self.c_attn = Conv1D(n_state * 3, nx)
|
||||
self.is_cross_attention = is_cross_attention
|
||||
if self.is_cross_attention:
|
||||
self.c_attn = Conv1D(2 * n_state, nx)
|
||||
self.q_attn = Conv1D(n_state, nx)
|
||||
else:
|
||||
self.c_attn = Conv1D(3 * n_state, nx)
|
||||
self.c_proj = Conv1D(n_state, nx)
|
||||
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
@@ -160,8 +164,11 @@ class Attention(nn.Module):
|
||||
if self.scale:
|
||||
w = w / (float(v.size(-1)) ** 0.5)
|
||||
nd, ns = w.size(-2), w.size(-1)
|
||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
||||
|
||||
if not self.is_cross_attention:
|
||||
# if only "normal" attention layer implements causal mask
|
||||
mask = self.bias[:, :, ns - nd : ns, :ns]
|
||||
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask
|
||||
@@ -193,10 +200,26 @@ class Attention(nn.Module):
|
||||
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
||||
|
||||
def forward(
|
||||
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
x = self.c_attn(x)
|
||||
query, key, value = x.split(self.split_size, dim=2)
|
||||
if encoder_hidden_states is not None:
|
||||
assert hasattr(
|
||||
self, "q_attn"
|
||||
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
|
||||
query = self.q_attn(hidden_states)
|
||||
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
||||
attention_mask = encoder_attention_mask
|
||||
else:
|
||||
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
||||
|
||||
query = self.split_heads(query)
|
||||
key = self.split_heads(key, k=True)
|
||||
value = self.split_heads(value)
|
||||
@@ -239,32 +262,64 @@ class MLP(nn.Module):
|
||||
class Block(nn.Module):
|
||||
def __init__(self, n_ctx, config, scale=False):
|
||||
super().__init__()
|
||||
nx = config.n_embd
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
|
||||
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
self.attn = Attention(nx, n_ctx, config, scale)
|
||||
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
|
||||
hidden_size = config.n_embd
|
||||
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
||||
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.attn = Attention(hidden_size, n_ctx, config, scale)
|
||||
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
if config.add_cross_attention:
|
||||
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
|
||||
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
||||
self.mlp = MLP(inner_dim, config)
|
||||
|
||||
def forward(
|
||||
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
|
||||
self,
|
||||
hidden_states,
|
||||
layer_past=None,
|
||||
attention_mask=None,
|
||||
head_mask=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=False,
|
||||
output_attentions=False,
|
||||
):
|
||||
output_attn = self.attn(
|
||||
self.ln_1(x),
|
||||
attn_outputs = self.attn(
|
||||
self.ln_1(hidden_states),
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
||||
outputs = attn_outputs[1:]
|
||||
# residual connection
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
x = x + a
|
||||
m = self.mlp(self.ln_2(x))
|
||||
x = x + m
|
||||
if encoder_hidden_states is not None:
|
||||
# add one self-attention block for cross-attention
|
||||
assert hasattr(
|
||||
self, "crossattention"
|
||||
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
||||
cross_attn_outputs = self.crossattention(
|
||||
self.ln_cross_attn(hidden_states),
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
attn_output = cross_attn_outputs[0]
|
||||
# residual connection
|
||||
hidden_states = hidden_states + attn_output
|
||||
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
|
||||
|
||||
outputs = [x] + output_attn[1:]
|
||||
return outputs # x, present, (attentions)
|
||||
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
||||
# residual connection
|
||||
hidden_states = hidden_states + feed_forward_hidden_states
|
||||
|
||||
outputs = [hidden_states] + outputs
|
||||
return outputs # hidden_states, present, (cross_attentions, attentions)
|
||||
|
||||
|
||||
class GPT2PreTrainedModel(PreTrainedModel):
|
||||
@@ -449,6 +504,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
@@ -506,7 +563,7 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
attention_mask = attention_mask[:, None, None, :]
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
@@ -516,6 +573,17 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
attention_mask = (1.0 - attention_mask) * -10000.0
|
||||
|
||||
# If a 2D ou 3D attention mask is provided for the cross-attention
|
||||
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
|
||||
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
||||
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
||||
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
||||
if encoder_attention_mask is None:
|
||||
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
||||
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
||||
else:
|
||||
encoder_attention_mask = None
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
@@ -546,6 +614,8 @@ class GPT2Model(GPT2PreTrainedModel):
|
||||
layer_past=layer_past,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
)
|
||||
@@ -593,17 +663,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
def get_output_embeddings(self):
|
||||
return self.lm_head
|
||||
|
||||
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
|
||||
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
}
|
||||
|
||||
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
|
||||
@add_code_sample_docstrings(
|
||||
tokenizer_class=_TOKENIZER_FOR_DOC,
|
||||
checkpoint="ctrl",
|
||||
checkpoint="gpt2",
|
||||
output_type=CausalLMOutputWithPast,
|
||||
config_class=_CONFIG_FOR_DOC,
|
||||
)
|
||||
@@ -616,6 +690,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
position_ids=None,
|
||||
head_mask=None,
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
labels=None,
|
||||
use_cache=None,
|
||||
output_attentions=None,
|
||||
@@ -648,6 +724,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
|
||||
Reference in New Issue
Block a user