Add BertDecoderModel and Bert2Bert classes
Note: it has not yet been verified what happens when these classes are initialized with the pretrained weights.
This commit is contained in:
@@ -788,6 +788,110 @@ class BertModel(BertPreTrainedModel):
|
|||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
|
@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
    The model follows the general transformer decoder architecture.""",
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertDecoderModel(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Last layer hidden-state of the first token of the sequence (classification token)
            further processed by a Linear layer and a Tanh activation function. The Linear
            layer weights are trained from the next sentence prediction (classification)
            objective during Bert pretraining. This output is usually *not* a good summary
            of the semantic content of the input, you're often better with averaging or pooling
            the sequence of hidden-states for the whole input sequence.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        encoder = BertModel.from_pretrained('bert-base-uncased')
        model = BertDecoderModel.from_pretrained('bert-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        encoder_hidden_states = encoder(input_ids)[0]  # The decoder needs the encoder output as second argument
        outputs = model(input_ids, encoder_hidden_states)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        # The first argument of super() must be this class: BertDecoderModel does not
        # inherit from BertModel, so `super(BertModel, self)` raises a TypeError.
        super(BertDecoderModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.decoder = BertDecoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replaces the word embedding matrix by one resized to `new_num_tokens` entries
            and returns the new embedding module.
        """
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        # This model owns a `decoder`, not an `encoder`; the previous `self.encoder`
        # lookup raised an AttributeError.
        # NOTE(review): assumes BertDecoder exposes its blocks as `.layer[i].attention`
        # the same way BertEncoder does -- TODO confirm against BertDecoder.
        for layer, heads in heads_to_prune.items():
            self.decoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """ Embeds `input_ids`, runs the decoder stack over them together with
            `encoder_outputs`, and pools the resulting sequence.

            input_ids: token indices fed to the embedding layer.
            encoder_outputs: forwarded untouched to `self.decoder` as its second argument.
                Presumably the encoder's last hidden state (`Bert2Bert.forward` passes
                `encoder_outputs[0]` of a `BertModel`) -- TODO confirm expected shape.
            attention_mask: optional mask over `input_ids`, 1 for positions to attend to and
                0 for masked positions; defaults to all ones.
            token_type_ids: optional segment ids; defaults to all zeros.
            position_ids: optional position indices, handed to the embedding layer as is.
            head_mask: optional tensor of dimension 1 (``[num_heads]``) or
                2 (``[num_hidden_layers x num_heads]``), 1.0 meaning the head is kept.

            Returns the tuple ``(sequence_output, pooled_output) + decoder_outputs[1:]``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        # NOTE(review): no causal (triangular) mask is built here; whether the decoder
        # applies one internally is not visible from this class -- TODO confirm.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        decoder_outputs = self.decoder(embedding_output,
                                       encoder_outputs,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = decoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # The extra elements must come from the decoder's own outputs. Slicing the
        # `encoder_outputs` argument returned the caller's input (and fails with a
        # TypeError when it is a tensor, since a tuple cannot be added to a tensor).
        outputs = (sequence_output, pooled_output,) + decoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
@add_start_docstrings("""Bert Model with two heads on top as done during the pre-training:
|
||||||
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
a `masked language modeling` head and a `next sentence prediction (classification)` head. """,
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
@@ -1312,13 +1416,20 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
|||||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings("Bert encoder-decoder model",
|
@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
|
||||||
BERT_START_DOCSTRING,
|
BERT_START_DOCSTRING,
|
||||||
BERT_INPUTS_DOCSTRING)
|
BERT_INPUTS_DOCSTRING)
|
||||||
class Bert2Bert(BertPreTrainedModel):
|
class Bert2Bert(BertPreTrainedModel):
|
||||||
r"""
|
r"""
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
|
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||||
|
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||||
|
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||||
|
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||||
|
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
|
||||||
|
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
|
||||||
|
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
|
||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
@@ -1328,45 +1439,33 @@ class Bert2Bert(BertPreTrainedModel):
|
|||||||
outputs = model(input)
|
outputs = model(input)
|
||||||
output_text = tokenize.decode(outputs[0])
|
output_text = tokenize.decode(outputs[0])
|
||||||
print(output_text)
|
print(output_text)
|
||||||
|
|
||||||
|
References::
|
||||||
|
|
||||||
|
[1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
|
||||||
|
[2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(Bert2Bert, self).__init__(config)
|
super(Bert2Bert, self).__init__(config)
|
||||||
self.embeddings = BertEmbeddings(config)
|
self.encoder = BertModel(config)
|
||||||
self.encoder = BertEncoder(config)
|
self.decoder = BertDecoderModel(config)
|
||||||
self.decoder = BertDecoder(config)
|
|
||||||
|
|
||||||
def forward(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
self.init_weights()
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = torch.ones_like(inputs)
|
|
||||||
if token_type_ids is None:
|
|
||||||
token_type_ids = torch.zeros_like(inputs)
|
|
||||||
|
|
||||||
# We create a 3D attention mask from a 2D tensor mask.
|
def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
|
||||||
# Sizes are [batch_size, 1, 1, to_seq_length]
|
encoder_outputs = self.encoder(input_ids,
|
||||||
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
attention_mask=attention_mask,
|
||||||
# this attention mask is more simple than the triangular masking of causal attention
|
token_type_ids=token_type_ids,
|
||||||
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
position_ids=position_ids,
|
||||||
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
|
|
||||||
|
|
||||||
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
|
||||||
# masked positions, this operation will create a tensor which is 0.0 for
|
|
||||||
# positions we want to attend and -10000.0 for masked positions.
|
|
||||||
# Since we are adding it to the raw scores before the softmax, this is
|
|
||||||
# effectively the same as removing these entirely.
|
|
||||||
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
|
|
||||||
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
|
|
||||||
|
|
||||||
embedding_output = self.embeddings(inputs, position_ids=position_ids, token_type_ids=token_type_ids)
|
|
||||||
encoder_outputs = self.encoder(embedding_output,
|
|
||||||
extended_attention_mask,
|
|
||||||
head_mask=head_mask)
|
head_mask=head_mask)
|
||||||
decoder_outputs = self.decoder(embedding_output,
|
encoder_output = encoder_outputs[0]
|
||||||
encoder_outputs[0],
|
|
||||||
extended_attention_mask,
|
|
||||||
head_mask=head_mask)
|
|
||||||
sequence_output = decoder_outputs[0]
|
|
||||||
pooled_output = self.pooler(sequence_output)
|
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[1:] # add hidden_states and attentions if they are here
|
decoder_input = torch.empty_like(input_ids).normal_(mean=0.0, std=self.config.initializer_range)
|
||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
decoder_outputs = self.decoder(decoder_input,
|
||||||
|
encoder_output,
|
||||||
|
token_type_ids=token_type_ids,
|
||||||
|
position_ids=position_ids,
|
||||||
|
head_mask=head_mask)
|
||||||
|
return decoder_outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user