From 75feacf172d5aa26314929c0f3bb54d9e845e00b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Tue, 8 Oct 2019 11:39:04 +0200
Subject: [PATCH] add general structure for Bert2Bert class

---
 transformers/modeling_bert.py | 82 +++++++++++++++++++++++++++++++++++
 1 file changed, 82 insertions(+)

diff --git a/transformers/modeling_bert.py b/transformers/modeling_bert.py
index a5e36eaed0..1e228556b4 100644
--- a/transformers/modeling_bert.py
+++ b/transformers/modeling_bert.py
@@ -1310,3 +1310,85 @@ class BertForQuestionAnswering(BertPreTrainedModel):
             outputs = (total_loss,) + outputs
 
         return outputs  # (loss), start_logits, end_logits, (hidden_states), (attentions)
+
+
+@add_start_docstrings("Bert encoder-decoder model",
+                      BERT_START_DOCSTRING,
+                      BERT_INPUTS_DOCSTRING)
+class Bert2Bert(BertPreTrainedModel):
+    r"""
+    Encoder-decoder model: the embedded input goes through ``self.encoder`` and
+    then through ``self.decoder`` together with the first encoder output.
+
+    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
+        **sequence_output**: first element of the decoder outputs.
+        **pooled_output**: result of applying ``self.pooler`` to ``sequence_output``.
+        **hidden_states**, **attentions**: (`optional`) whatever the encoder returns
+            after its first output.
+
+    Examples::
+
+        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+        model = Bert2Bert.from_pretrained('bert-base-uncased')
+        input_ids = torch.tensor(tokenizer.encode("Hello, how are you?")).unsqueeze(0)  # Batch size 1
+        outputs = model(input_ids)
+        sequence_output, pooled_output = outputs[:2]
+    """
+
+    def __init__(self, config):
+        super(Bert2Bert, self).__init__(config)
+        self.embeddings = BertEmbeddings(config)
+        self.encoder = BertEncoder(config)
+        self.decoder = BertDecoder(config)
+        # forward() calls self.pooler, so the module has to be created here.
+        self.pooler = BertPooler(config)
+
+    def forward(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        """Embed ``inputs``, encode them, then decode using the encoder output.
+
+        Args:
+            inputs: tensor of token indices (presumably [batch_size, seq_length]).
+            attention_mask: optional mask with 1 for positions to attend to and 0
+                for masked positions; defaults to all ones.
+            token_type_ids: optional tensor; defaults to all zeros.
+            position_ids: optional, forwarded to the embeddings.
+            head_mask: optional, forwarded unchanged to the encoder and decoder.
+                NOTE(review): BertModel expands a None head_mask before calling
+                its encoder; confirm whether the same is needed here.
+
+        Returns:
+            Tuple ``(sequence_output, pooled_output) + encoder_outputs[1:]``.
+        """
+        if attention_mask is None:
+            attention_mask = torch.ones_like(inputs)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(inputs)
+
+        # We create a 4D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, to_seq_length]
+        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+        # this attention mask is more simple than the triangular masking of causal attention
+        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+
+        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+        # masked positions, this operation will create a tensor which is 0.0 for
+        # positions we want to attend and -10000.0 for masked positions.
+        # Since we are adding it to the raw scores before the softmax, this is
+        # effectively the same as removing these entirely.
+        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+
+        embedding_output = self.embeddings(inputs, position_ids=position_ids, token_type_ids=token_type_ids)
+        encoder_outputs = self.encoder(embedding_output,
+                                       extended_attention_mask,
+                                       head_mask=head_mask)
+        decoder_outputs = self.decoder(embedding_output,
+                                       encoder_outputs[0],
+                                       extended_attention_mask,
+                                       head_mask=head_mask)
+        sequence_output = decoder_outputs[0]
+        pooled_output = self.pooler(sequence_output)
+
+        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
+        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)