[EncoderDecoder] Add Cross Attention for GPT2 (#6415)

* add cross attention layers for gpt2

* make gpt2 cross attention work

* finish bert2gpt2

* add explicit comments

* remove attention mask since not yet supported

* revert attn mask in pipeline

* Update src/transformers/modeling_gpt2.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_encoder_decoder.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Patrick von Platen
2020-08-14 09:43:29 +02:00
committed by GitHub
parent eb613b566a
commit 1d6e71e116
5 changed files with 239 additions and 60 deletions

View File

@@ -372,11 +372,16 @@ class GenerationMixin:
if self.config.is_encoder_decoder:
if decoder_start_token_id is None:
decoder_start_token_id = bos_token_id
# see if BOS token can be used for decoder_start_token_id
if bos_token_id is not None:
decoder_start_token_id = bos_token_id
elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
decoder_start_token_id = self.config.decoder.bos_token_id
else:
raise ValueError(
"decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
)
assert (
decoder_start_token_id is not None
), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)

View File

@@ -287,6 +287,8 @@ class EncoderDecoderModel(PreTrainedModel):
**kwargs_decoder,
)
# TODO(PVP): currently it is not possible to use `past`
# with the encoder/decoder framework -> should be implemented
return decoder_outputs + encoder_outputs
def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
@@ -299,15 +301,24 @@ class EncoderDecoderModel(PreTrainedModel):
encoder_outputs = (past,)
decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
return {
decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
input_dict = {
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_inputs["attention_mask"],
"decoder_attention_mask": decoder_attention_mask,
"decoder_input_ids": decoder_inputs["input_ids"],
"encoder_outputs": encoder_outputs,
}
# Ideally all models should have a `use_cache`
# leave following to ifs until all have it implemented
if "use_cache" in decoder_inputs:
input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
if "past_key_values" in decoder_inputs:
input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
return input_dict
def _reorder_cache(self, past, beam_idx):
# as a default encoder-decoder models do not re-order the past.
# TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
return past
# apply decoder cache reordering here
return self.decoder._reorder_cache(past, beam_idx)

View File

@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
class Attention(nn.Module):
def __init__(self, nx, n_ctx, config, scale=False):
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
super().__init__()
n_state = nx # in Attention: n_state=768 (nx=n_embd)
@@ -131,8 +131,12 @@ class Attention(nn.Module):
self.n_head = config.n_head
self.split_size = n_state
self.scale = scale
self.c_attn = Conv1D(n_state * 3, nx)
self.is_cross_attention = is_cross_attention
if self.is_cross_attention:
self.c_attn = Conv1D(2 * n_state, nx)
self.q_attn = Conv1D(n_state, nx)
else:
self.c_attn = Conv1D(3 * n_state, nx)
self.c_proj = Conv1D(n_state, nx)
self.attn_dropout = nn.Dropout(config.attn_pdrop)
self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -160,8 +164,11 @@ class Attention(nn.Module):
if self.scale:
w = w / (float(v.size(-1)) ** 0.5)
nd, ns = w.size(-2), w.size(-1)
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
if not self.is_cross_attention:
# if only "normal" attention layer implements causal mask
mask = self.bias[:, :, ns - nd : ns, :ns]
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
if attention_mask is not None:
# Apply the attention mask
@@ -193,10 +200,26 @@ class Attention(nn.Module):
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
x = self.c_attn(x)
query, key, value = x.split(self.split_size, dim=2)
if encoder_hidden_states is not None:
assert hasattr(
self, "q_attn"
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
query = self.q_attn(hidden_states)
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
attention_mask = encoder_attention_mask
else:
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
query = self.split_heads(query)
key = self.split_heads(key, k=True)
value = self.split_heads(value)
@@ -239,32 +262,64 @@ class MLP(nn.Module):
class Block(nn.Module):
def __init__(self, n_ctx, config, scale=False):
super().__init__()
nx = config.n_embd
inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
self.attn = Attention(nx, n_ctx, config, scale)
self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
hidden_size = config.n_embd
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = Attention(hidden_size, n_ctx, config, scale)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
if config.add_cross_attention:
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = MLP(inner_dim, config)
def forward(
self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
self,
hidden_states,
layer_past=None,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=False,
output_attentions=False,
):
output_attn = self.attn(
self.ln_1(x),
attn_outputs = self.attn(
self.ln_1(hidden_states),
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
a = output_attn[0] # output_attn: a, present, (attentions)
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
outputs = attn_outputs[1:]
# residual connection
hidden_states = attn_output + hidden_states
x = x + a
m = self.mlp(self.ln_2(x))
x = x + m
if encoder_hidden_states is not None:
# add one self-attention block for cross-attention
assert hasattr(
self, "crossattention"
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
cross_attn_outputs = self.crossattention(
self.ln_cross_attn(hidden_states),
attention_mask=attention_mask,
head_mask=head_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
)
attn_output = cross_attn_outputs[0]
# residual connection
hidden_states = hidden_states + attn_output
outputs = outputs + cross_attn_outputs[1:] # add cross attentions if we output attention weights
outputs = [x] + output_attn[1:]
return outputs # x, present, (attentions)
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
# residual connection
hidden_states = hidden_states + feed_forward_hidden_states
outputs = [hidden_states] + outputs
return outputs # hidden_states, present, (cross_attentions, attentions)
class GPT2PreTrainedModel(PreTrainedModel):
@@ -449,6 +504,8 @@ class GPT2Model(GPT2PreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
@@ -506,7 +563,7 @@ class GPT2Model(GPT2PreTrainedModel):
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = attention_mask[:, None, None, :]
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
@@ -516,6 +573,17 @@ class GPT2Model(GPT2PreTrainedModel):
attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
attention_mask = (1.0 - attention_mask) * -10000.0
# If a 2D ou 3D attention mask is provided for the cross-attention
# we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
if self.config.add_cross_attention and encoder_hidden_states is not None:
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
if encoder_attention_mask is None:
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
else:
encoder_attention_mask = None
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
@@ -546,6 +614,8 @@ class GPT2Model(GPT2PreTrainedModel):
layer_past=layer_past,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
)
@@ -593,17 +663,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
def get_output_embeddings(self):
return self.lm_head
def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
# only last token for inputs_ids if past is defined in kwargs
if past:
input_ids = input_ids[:, -1].unsqueeze(-1)
return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
return {
"input_ids": input_ids,
"past_key_values": past,
"use_cache": kwargs.get("use_cache"),
}
@add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
@add_code_sample_docstrings(
tokenizer_class=_TOKENIZER_FOR_DOC,
checkpoint="ctrl",
checkpoint="gpt2",
output_type=CausalLMOutputWithPast,
config_class=_CONFIG_FOR_DOC,
)
@@ -616,6 +690,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
position_ids=None,
head_mask=None,
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
labels=None,
use_cache=None,
output_attentions=None,
@@ -648,6 +724,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,