Cleaning up seq2seq [WIP]
This commit is contained in:
@@ -199,12 +199,14 @@ class BertSelfAttention(nn.Module):
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
if encoder_hidden_states is not None: # if encoder-decoder attention
|
||||
mixed_query_layer = self.query(encoder_hidden_states)
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
# if the attention Module is a encoder-decoder self attention module
|
||||
if encoder_hidden_states is not None:
|
||||
mixed_key_layer = self.key(encoder_hidden_states)
|
||||
mixed_value_layer = self.value(encoder_hidden_states)
|
||||
else:
|
||||
mixed_query_layer = self.query(hidden_states)
|
||||
mixed_key_layer = self.key(hidden_states)
|
||||
mixed_value_layer = self.value(hidden_states)
|
||||
|
||||
query_layer = self.transpose_for_scores(mixed_query_layer)
|
||||
key_layer = self.transpose_for_scores(mixed_key_layer)
|
||||
@@ -322,26 +324,25 @@ class BertLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertLayer, self).__init__()
|
||||
self.attention = BertAttention(config)
|
||||
if getattr(config, "is_decoder", False):
|
||||
self.is_decoder = config.is_decoder
|
||||
if self.is_decoder:
|
||||
self.crossattention = BertAttention(config)
|
||||
self.intermediate = BertIntermediate(config)
|
||||
self.output = BertOutput(config)
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_state=None):
|
||||
attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = attention_outputs[0]
|
||||
self_attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
|
||||
attention_output = self_attention_outputs[0]
|
||||
outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
|
||||
|
||||
if encoder_hidden_state is not None:
|
||||
try:
|
||||
attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||
except AttributeError as ae:
|
||||
print("You need to set `is_encoder` to True in the configuration to instantiate an encoder layer:", ae)
|
||||
raise
|
||||
if self.is_decoder and encoder_hidden_state is not None:
|
||||
cross_attention_outputs = self.crossattention(attention_output, attention_mask, head_mask, encoder_hidden_state)
|
||||
attention_output = cross_attention_outputs[0]
|
||||
outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
|
||||
|
||||
attention_output = attention_outputs[0]
|
||||
intermediate_output = self.intermediate(attention_output)
|
||||
layer_output = self.output(intermediate_output, attention_output)
|
||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||
outputs = (layer_output,) + outputs
|
||||
return outputs
|
||||
|
||||
|
||||
@@ -352,14 +353,14 @@ class BertEncoder(nn.Module):
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None):
|
||||
def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i])
|
||||
layer_outputs = layer_module(hidden_states, attention_mask, head_mask[i], encoder_hidden_states)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if self.output_attentions:
|
||||
@@ -377,42 +378,6 @@ class BertEncoder(nn.Module):
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertDecoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertDecoder, self).__init__()
|
||||
config.is_decoder = True
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(self, hidden_states, encoder_outputs, attention_mask=None, head_mask=None):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
head_mask=head_mask[i],
|
||||
encoder_hidden_state=encoder_outputs)
|
||||
if self.output_attentions:
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if self.output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||
|
||||
|
||||
class BertPooler(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPooler, self).__init__()
|
||||
@@ -635,7 +600,8 @@ class BertModel(BertPreTrainedModel):
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.encoder.layer[layer].attention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
|
||||
head_mask=None, encoder_hidden_state=None):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
@@ -673,8 +639,9 @@ class BertModel(BertPreTrainedModel):
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
encoder_outputs = self.encoder(embedding_output,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
attention_mask=extended_attention_mask,
|
||||
head_mask=head_mask,
|
||||
encoder_hidden_state=encoder_hidden_state)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
@@ -682,111 +649,6 @@ class BertModel(BertPreTrainedModel):
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
|
||||
The model follows the general transformer decoder architecture.""",
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class BertDecoderModel(BertPreTrainedModel):
|
||||
r"""
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the output of the last layer of the model.
|
||||
**pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
|
||||
Last layer hidden-state of the first token of the sequence (classification token)
|
||||
further processed by a Linear layer and a Tanh activation function. The Linear
|
||||
layer weights are trained from the next sentence prediction (classification)
|
||||
objective during Bert pretraining. This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = BertDecoderModel.from_pretrained('bert-base-uncased')
|
||||
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
|
||||
outputs = model(input_ids)
|
||||
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
|
||||
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(BertDecoderModel, self).__init__(config)
|
||||
|
||||
self.embeddings = BertEmbeddings(config)
|
||||
self.decoder = BertDecoder(config)
|
||||
self.pooler = BertPooler(config)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _resize_token_embeddings(self, new_num_tokens):
|
||||
old_embeddings = self.embeddings.word_embeddings
|
||||
new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
|
||||
self.embeddings.word_embeddings = new_embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
""" Prunes heads of the model.
|
||||
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
||||
See base class PreTrainedModel
|
||||
"""
|
||||
for layer, heads in heads_to_prune.items():
|
||||
self.decoder.layer[layer].attention.prune_heads(heads)
|
||||
self.decoder.layer[layer].crossattention.prune_heads(heads)
|
||||
|
||||
def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
if token_type_ids is None:
|
||||
token_type_ids = torch.zeros_like(input_ids)
|
||||
|
||||
# We create a 3D attention mask from a 2D tensor mask.
|
||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
||||
# this attention mask is more simple than the triangular masking of causal attention
|
||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
||||
# masked positions, this operation will create a tensor which is 0.0 for
|
||||
# positions we want to attend and -10000.0 for masked positions.
|
||||
# Since we are adding it to the raw scores before the softmax, this is
|
||||
# effectively the same as removing these entirely.
|
||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
||||
|
||||
# Prepare head mask if needed
|
||||
# 1.0 in head_mask indicate we keep the head
|
||||
# attention_probs has shape bsz x n_heads x N x N
|
||||
# input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
|
||||
# and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
|
||||
if head_mask is not None:
|
||||
if head_mask.dim() == 1:
|
||||
head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
|
||||
elif head_mask.dim() == 2:
|
||||
head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer
|
||||
head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility
|
||||
else:
|
||||
head_mask = [None] * self.config.num_hidden_layers
|
||||
|
||||
embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
|
||||
decoder_outputs = self.decoder(embedding_output,
|
||||
encoder_outputs,
|
||||
extended_attention_mask,
|
||||
head_mask=head_mask)
|
||||
sequence_output = decoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
|
||||
outputs = (sequence_output, pooled_output,) + decoder_outputs[1:] # add hidden_states and attentions if they are here
|
||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||
BERT_START_DOCSTRING,
|
||||
@@ -1309,101 +1171,3 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
||||
BERT_START_DOCSTRING,
|
||||
BERT_INPUTS_DOCSTRING)
|
||||
class Bert2Rnd(BertPreTrainedModel):
|
||||
r"""
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||
|
||||
Examples::
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
model = Bert2Rnd.from_pretrained('bert-base-uncased')
|
||||
# fine-tuning magic happens here
|
||||
input = tokenizer.encode("Hello, how are you?")
|
||||
outputs = model(input)
|
||||
output_text = tokenize.decode(outputs[0])
|
||||
print(output_text)
|
||||
|
||||
References::
|
||||
|
||||
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
|
||||
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, config):
|
||||
super(Bert2Rnd, self).__init__(config)
|
||||
self.encoder = BertModel(config)
|
||||
self.decoder = BertDecoderModel(config)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_or_path, *model_args, **model_kwargs):
|
||||
""" Load the pretrained weights in the encoder.
|
||||
|
||||
The encoder of `Bert2Rand` is initialized with pretrained weights; the
|
||||
weights of the decoder are initialized at random except the embeddings
|
||||
which are initialized with the pretrained embeddings. We thus need to override
|
||||
the base class' `from_pretrained` method.
|
||||
"""
|
||||
|
||||
# Load the configuration
|
||||
config = model_kwargs.pop('config', None)
|
||||
if config is None:
|
||||
cache_dir = model_kwargs.pop('cache_dir', None)
|
||||
force_download = model_kwargs.pop('force_download', False)
|
||||
config, _ = cls.config_class.from_pretrained(
|
||||
pretrained_model_or_path,
|
||||
*model_args,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
**model_kwargs
|
||||
)
|
||||
model = cls(config)
|
||||
|
||||
# We load the encoder with pretrained weights
|
||||
pretrained_encoder = BertModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
||||
model.encoder = pretrained_encoder
|
||||
|
||||
# We load the decoder with pretrained weights and then randomize all weights but embeddings-related one.
|
||||
def randomize_decoder_weights(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
# Slightly different from the TF version which uses truncated_normal for initialization
|
||||
# cf https://github.com/pytorch/pytorch/pull/5617
|
||||
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
|
||||
elif isinstance(module, BertLayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
pretrained_decoder = BertDecoderModel.from_pretrained(pretrained_model_or_path, *model_args, **model_kwargs)
|
||||
pretrained_decoder.apply(randomize_decoder_weights)
|
||||
model.decoder = pretrained_decoder
|
||||
|
||||
return model
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||
encoder_outputs = self.encoder(input_ids,
|
||||
attention_mask=attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
decoder_outputs = self.decoder(input_ids,
|
||||
encoder_outputs[0],
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
head_mask=head_mask)
|
||||
return decoder_outputs
|
||||
|
||||
Reference in New Issue
Block a user