Cleaning up seq2seq [WIP]

This commit is contained in:
thomwolf
2019-10-14 11:58:13 +02:00
parent b3261e7ace
commit 0ef9bc923a
2 changed files with 273 additions and 260 deletions

View File

@@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module):
return x.permute(0, 2, 1, 3)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
if encoder_hidden_states is not None: # if encoder-decoder attention
mixed_query_layer = self.query(encoder_hidden_states)
mixed_query_layer = self.query(hidden_states)
# if the attention Module is a encoder-decoder self attention module
if encoder_hidden_states is not None:
mixed_key_layer = self.key(encoder_hidden_states)
mixed_value_layer = self.value(encoder_hidden_states)
else:
mixed_query_layer = self.query(hidden_states)
mixed_key_layer = self.key(hidden_states)
mixed_value_layer = self.value(hidden_states)
query_layer = self.transpose_for_scores(mixed_query_layer)
key_layer = self.transpose_for_scores(mixed_key_layer)
@@ -322,26 +324,25 @@ class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.attention = BertAttention(config)
if getattr(config, "is_decoder", False):
self.is_decoder = config.is_decoder
if self.is_decoder:
self.crossattention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = attention_outputs[0]
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
attention_output = self_attention_outputs[0]
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
if encoder_hidden_state is not None:
try:
attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
except AttributeError as ae:
print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
raise
if self.is_decoder and encoder_hidden_state is not None:
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
attention_output = cross_attention_outputs[0]
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
attention_output = attention_outputs[0]
intermediate_output = self.intermediate(attention_output)
layer_output = self.output(intermediate_output, attention_output)
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
outputs = (layer_output,) + outputs
return outputs
@@ -352,14 +353,14 @@ class BertEncoder(nn.Module):
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, attention_mask=None, head_mask=None):
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
hidden_states = layer_outputs[0]
if self.output_attentions:
@@ -377,42 +378,6 @@ class BertEncoder(nn.Module):
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
config.is_decoder = True
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(hidden_states,
attention_mask=attention_mask,
head_mask=head_mask[i],
encoder_hidden_state=encoder_outputs)
if self.output_attentions:
all_attentions = all_attentions + (layer_outputs[1],)
hidden_states = layer_outputs[0]
# Add last layer
if self.output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
outputs = outputs + (all_hidden_states,)
if self.output_attentions:
outputs = outputs + (all_attentions,)
return outputs # last-layer hidden state, (all hidden states), (all attentions)
class BertPooler(nn.Module):
def __init__(self, config):
super(BertPooler, self).__init__()
@@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel):
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
head_mask=None, encoder_hidden_state=None):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
@@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel):
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
encoder_outputs = self.encoder(embedding_output,
extended_attention_mask,
head_mask=head_mask)
attention_mask=extended_attention_mask,
head_mask=head_mask,
encoder_hidden_state=encoder_hidden_state)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel):
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
The model follows the general transformer decoder architecture.""",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
class BertDecoderModel(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the output of the last layer of the model.
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
Last layer hidden-state of the first token of the sequence (classification token)
further processed by a Linear layer and a Tanh activation function. The Linear
layer weights are trained from the next sentence prediction (classification)
objective during Bert pretraining. This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertDecoderModel.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def __init__(self, config):
super(BertDecoderModel, self).__init__(config)
self.embeddings = BertEmbeddings(config)
self.decoder = BertDecoder(config)
self.pooler = BertPooler(config)
self.init_weights()
def _resize_token_embeddings(self, new_num_tokens):
old_embeddings = self.embeddings.word_embeddings
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
self.embeddings.word_embeddings = new_embeddings
return self.embeddings.word_embeddings
def _prune_heads(self, heads_to_prune):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
See base class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.decoder.layer[layer].attention.prune_heads(heads)
self.decoder.layer[layer].crossattention.prune_heads(heads)
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
if token_type_ids is None:
token_type_ids = torch.zeros_like(input_ids)
# We create a 3D attention mask from a 2D tensor mask.
# Sizes are [batch_size, 1, 1, to_seq_length]
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
# this attention mask is more simple than the triangular masking of causal attention
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
# masked positions, this operation will create a tensor which is 0.0 for
# positions we want to attend and -10000.0 for masked positions.
# Since we are adding it to the raw scores before the softmax, this is
# effectively the same as removing these entirely.
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
if head_mask is not None:
if head_mask.dim() == 1:
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
elif head_mask.dim() == 2:
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
else:
head_mask = [None] * self.config.num_hidden_layers
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
decoder_outputs = self.decoder(embedding_output,
encoder_outputs,
extended_attention_mask,
head_mask=head_mask)
sequence_output = decoder_outputs[0]
pooled_output = self.pooler(sequence_output)
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
BERT_START_DOCSTRING,
@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel):
outputs = (total_loss,) + outputs
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
BERT_START_DOCSTRING,
BERT_INPUTS_DOCSTRING)
class Bert2Rnd(BertPreTrainedModel):
r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = Bert2Rnd.from_pretrained('bert-base-uncased')
# fine-tuning magic happens here
input = tokenizer.encode("Hello, how are you?")
outputs = model(input)
output_text = tokenize.decode(outputs[0])
print(output_text)
References::
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
"""
def __init__(self, config):
super(Bert2Rnd, self).__init__(config)
self.encoder = BertModel(config)
self.decoder = BertDecoderModel(config)
@classmethod
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
""" Load the pretrained weights in the encoder.
The encoder of `Bert2Rand` is initialized with pretrained weights; the
weights of the decoder are initialized at random except the embeddings
which are initialized with the pretrained embeddings. We thus need to override
the base class' `from_pretrained` method.
"""
# Load the configuration
config = model_kwargs.pop('config', None)
if config is None:
cache_dir = model_kwargs.pop('cache_dir', None)
force_download = model_kwargs.pop('force_download', False)
config, _ = cls.config_class.from_pretrained(
pretrained_model_or_path,
*model_args,
cache_dir=cache_dir,
return_unused_kwargs=True,
force_download=force_download,
**model_kwargs
)
model = cls(config)
# We load the encoder with pretrained weights
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
model.encoder = pretrained_encoder
# We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
def randomize_decoder_weights(module):
if isinstance(module, nn.Linear):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BertLayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
if isinstance(module, nn.Linear) and module.bias is not None:
module.bias.data.zero_()
pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
pretrained_decoder.apply(randomize_decoder_weights)
model.decoder = pretrained_decoder
return model
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
encoder_outputs = self.encoder(input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
decoder_outputs = self.decoder(input_ids,
encoder_outputs[0],
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask)
return decoder_outputs