# Reconstructed from the patch "Add BertDecoderModel and Bert2Bert classes"
# (transformers/modeling_bert.py). The originating commit message notes that
# the behaviour of these classes when initialized from pretrained weights has
# not been verified.


@add_start_docstrings("""A bare Bert decoder Model transformer outputting raw hidden-states without any specific head on top.
    The model follows the general transformer decoder architecture.""",
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class BertDecoderModel(BertPreTrainedModel):
    r"""
    Embeddings + decoder stack + pooler. Compared with ``BertModel``, ``forward`` takes one extra
    required positional argument:

        **encoder_outputs**:
            Passed unchanged to ``BertDecoder`` as its second positional argument.
            ``Bert2Bert`` passes the encoder's last hidden state here (presumably of shape
            ``(batch_size, source_sequence_length, hidden_size)`` -- TODO confirm against ``BertDecoder``).

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the decoder
            (first element returned by ``BertDecoder``).
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Result of applying ``BertPooler`` to ``last_hidden_state``.
        **additional decoder outputs**: (`optional`)
            Every element the decoder stack returns after its first one, appended unchanged.
            Presumably the hidden states and attentions when ``config.output_hidden_states=True`` /
            ``config.output_attentions=True`` -- TODO confirm against ``BertDecoder``.

    Examples::

        # NOTE: loading these classes from pretrained BERT checkpoints is unverified.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        config = BertConfig.from_pretrained('bert-base-uncased')
        encoder = BertModel(config)
        decoder = BertDecoderModel(config)
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        encoder_hidden_states = encoder(input_ids)[0]
        outputs = decoder(input_ids, encoder_hidden_states)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

    """
    def __init__(self, config):
        # The original called super(BertModel, self), which raises TypeError because
        # this class is not a subclass of BertModel.
        super(BertDecoderModel, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.decoder = BertDecoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _resize_token_embeddings(self, new_num_tokens):
        """ Replace the word-embedding matrix by a resized copy and return the new module. """
        old_embeddings = self.embeddings.word_embeddings
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens)
        self.embeddings.word_embeddings = new_embeddings
        return self.embeddings.word_embeddings

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        # This model has no `encoder` attribute (see __init__); prune inside the decoder stack.
        # NOTE(review): assumes BertDecoder exposes `.layer[i].attention` like BertEncoder -- TODO confirm.
        for layer, heads in heads_to_prune.items():
            self.decoder.layer[layer].attention.prune_heads(heads)

    def forward(self, input_ids, encoder_outputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """ Embed `input_ids`, run the decoder stack against `encoder_outputs`, and pool the result.

        Args:
            input_ids: token ids fed to the decoder embeddings.
            encoder_outputs: forwarded as-is to the decoder stack (see class docstring).
            attention_mask: 1 for positions to attend to, 0 for masked positions; defaults to all ones.
            token_type_ids: defaults to all zeros.
            position_ids: forwarded to the embeddings.
            head_mask: 1-D ``[num_heads]`` or 2-D ``[num_hidden_layers x num_heads]`` mask, 1.0 keeps a head.

        Returns:
            ``(sequence_output, pooled_output) + decoder_outputs[1:]``
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # NOTE(review): only a padding mask is built here, no triangular (causal) mask.
        # Unless BertDecoder adds one itself, every position can attend to later positions -- confirm.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        decoder_outputs = self.decoder(embedding_output,
                                       encoder_outputs,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = decoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        # The extra outputs come from the decoder stack. The original sliced `encoder_outputs`
        # (an input, typically a tensor), which cannot be concatenated to a tuple.
        outputs = (sequence_output, pooled_output,) + tuple(decoder_outputs[1:])
        return outputs  # sequence_output, pooled_output, (additional decoder outputs)


@add_start_docstrings("Bert encoder-decoder model for sequence generation.",
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class Bert2Bert(BertPreTrainedModel):
    r"""
    A ``BertModel`` encoder whose last hidden state is fed to a ``BertDecoderModel``.

    ``forward`` accepts one optional argument in addition to the usual Bert inputs:

        **decoder_input_ids**: (`optional`) token ids fed to the decoder.
            When omitted, ``input_ids`` is used for the decoder as well.
            ``token_type_ids``, ``position_ids`` and ``head_mask`` are shared between encoder and
            decoder, so when they are given they must also be valid for the decoder input.

    Outputs: the `Tuple` returned by ``BertDecoderModel.forward``:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the decoder.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            ``BertPooler`` applied to ``last_hidden_state``.
        **additional decoder outputs**: (`optional`)
            Whatever else the decoder stack returns (presumably hidden states / attentions when enabled
            in the config -- TODO confirm against ``BertDecoder``).

    Examples::

        # NOTE: loading this class from pretrained BERT checkpoints is unverified.
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        config = BertConfig.from_pretrained('bert-base-uncased')
        model = Bert2Bert(config)
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        decoder_hidden_states = outputs[0]  # no language-modeling head: these are hidden states, not token ids

    References::

        [1] "Leveraging Pre-trained Checkpoints for Sequence Generation Tasks", S.Rothe, S.Narayan & A.Severyn (2019) ArXiV:1907.12461v1
        [2] Tensor2Tensor library https://github.com/tensorflow/tensor2tensor

    """
    def __init__(self, config):
        super(Bert2Bert, self).__init__(config)

        self.encoder = BertModel(config)
        self.decoder = BertDecoderModel(config)

        self.init_weights()

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
                decoder_input_ids=None):
        """ Encode `input_ids`, then decode against the encoder's last hidden state.

        Returns the tuple produced by the decoder model.
        """
        encoder_outputs = self.encoder(input_ids,
                                       attention_mask=attention_mask,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       head_mask=head_mask)
        encoder_output = encoder_outputs[0]

        # The original filled `torch.empty_like(input_ids)` with normal noise. `input_ids` is an
        # integer tensor, for which `normal_` is not defined, and the decoder embeddings need
        # integer token ids anyway. Fall back to the encoder input when no decoder input is given.
        if decoder_input_ids is None:
            decoder_input_ids = input_ids

        decoder_outputs = self.decoder(decoder_input_ids,
                                       encoder_output,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
                                       head_mask=head_mask)
        return decoder_outputs