add general structure for Bert2Bert class
This commit is contained in:
@@ -1310,3 +1310,63 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("Bert encoder-decoder model",
                      BERT_START_DOCSTRING,
                      BERT_INPUTS_DOCSTRING)
class Bert2Bert(BertPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **sequence_output**: first element returned by the decoder.
        **pooled_output**: result of applying the pooler to ``sequence_output``.
        **(remaining elements)**: everything the encoder returns after its first
        element (hidden states and attentions, when the encoder provides them).

    Examples::

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        model = Bert2Bert.from_pretrained('bert-base-uncased')
        # forward() calls torch.ones_like / torch.zeros_like on its input, so a tensor is required
        input_ids = torch.tensor(tokenizer.encode("Hello, how are you?")).unsqueeze(0)  # batch of size 1
        outputs = model(input_ids)
        # NOTE(review): assumes outputs[0] can be decoded as token ids -- confirm
        output_text = tokenizer.decode(outputs[0])
        print(output_text)

    """

    def __init__(self, config):
        """Build the embeddings, encoder, decoder and pooler sub-modules from ``config``."""
        super(Bert2Bert, self).__init__(config)

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.decoder = BertDecoder(config)
        # forward() applies self.pooler to the decoder output; without this
        # attribute every forward pass fails with an AttributeError.
        # NOTE(review): BertPooler is presumably defined in this module next to
        # BertEmbeddings / BertEncoder -- confirm.
        self.pooler = BertPooler(config)

    def forward(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        """Embed ``inputs``, run the encoder, then the decoder, then pool the decoder output.

        Args:
            inputs: tensor handed to the embedding layer; also used as the
                template for the default masks below.
            attention_mask: 1 for positions to attend to, 0 for masked
                positions. Defaults to all ones (shaped like ``inputs``).
            token_type_ids: forwarded to the embedding layer. Defaults to all
                zeros (shaped like ``inputs``).
            position_ids: forwarded unchanged to the embedding layer.
            head_mask: forwarded unchanged to both the encoder and the decoder.

        Returns:
            Tuple ``(sequence_output, pooled_output) + encoder_outputs[1:]``.
        """
        if attention_mask is None:
            attention_mask = torch.ones_like(inputs)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(inputs)

        # Insert two broadcast dimensions into the mask. For a 2D
        # [batch_size, to_seq_length] mask the result has sizes
        # [batch_size, 1, 1, to_seq_length], so it can broadcast to
        # [batch_size, num_heads, from_seq_length, to_seq_length].
        # This attention mask is simpler than the triangular masking of causal
        # attention used in OpenAI GPT; we only need to prepare the broadcast
        # dimensions here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        embedding_output = self.embeddings(inputs, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        # The decoder receives the same embeddings as the encoder together with
        # the encoder's first output.
        decoder_outputs = self.decoder(embedding_output,
                                       encoder_outputs[0],
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = decoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        outputs = (sequence_output, pooled_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||
|
||||
Reference in New Issue
Block a user