add attention to all bert models and add test

This commit is contained in:
thomwolf
2019-06-14 16:28:25 +02:00
parent bcc9e93e6f
commit 5e1207b8ad
2 changed files with 140 additions and 47 deletions

View File

@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertForPreTraining, self).__init__(config)
self.bert = BertModel(config)
self.output_attentions = output_attentions
self.bert = BertModel(config, output_attentions=output_attentions)
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
outputs = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, pooled_output = outputs
else:
sequence_output, pooled_output = outputs
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
if masked_lm_labels is not None and next_sentence_label is not None:
@@ -830,8 +835,9 @@ class BertForPreTraining(BertPreTrainedModel):
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
total_loss = masked_lm_loss + next_sentence_loss
return total_loss
else:
return prediction_scores, seq_relationship_score
elif self.output_attentions:
return all_attentions, prediction_scores, seq_relationship_score
return prediction_scores, seq_relationship_score
class BertForMaskedLM(BertPreTrainedModel):
@@ -876,23 +882,29 @@ class BertForMaskedLM(BertPreTrainedModel):
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertForMaskedLM, self).__init__(config)
self.bert = BertModel(config)
self.output_attentions = output_attentions
self.bert = BertModel(config, output_attentions=output_attentions)
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
outputs = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, _ = outputs
else:
sequence_output, _ = outputs
prediction_scores = self.cls(sequence_output)
if masked_lm_labels is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
return masked_lm_loss
else:
return prediction_scores
elif self.output_attentions:
return all_attentions, prediction_scores
return prediction_scores
class BertForNextSentencePrediction(BertPreTrainedModel):
@@ -938,23 +950,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertForNextSentencePrediction, self).__init__(config)
self.bert = BertModel(config)
self.output_attentions = output_attentions
self.bert = BertModel(config, output_attentions=output_attentions)
self.cls = BertOnlyNSPHead(config)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
outputs = self.bert(input_ids, token_type_ids, attention_mask,
output_all_encoded_layers=False)
seq_relationship_score = self.cls( pooled_output)
if self.output_attentions:
all_attentions, _, pooled_output = outputs
else:
_, pooled_output = outputs
seq_relationship_score = self.cls(pooled_output)
if next_sentence_label is not None:
loss_fct = CrossEntropyLoss(ignore_index=-1)
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
return next_sentence_loss
else:
return seq_relationship_score
elif self.output_attentions:
return all_attentions, seq_relationship_score
return seq_relationship_score
class BertForSequenceClassification(BertPreTrainedModel):
@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels):
def __init__(self, config, num_labels, output_attentions=False):
super(BertForSequenceClassification, self).__init__(config)
self.output_attentions = output_attentions
self.num_labels = num_labels
self.bert = BertModel(config)
self.bert = BertModel(config, output_attentions=output_attentions)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, _, pooled_output = outputs
else:
_, pooled_output = outputs
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
@@ -1019,8 +1042,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
elif self.output_attentions:
return all_attentions, logits
return logits
class BertForMultipleChoice(BertPreTrainedModel):
@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_choices):
def __init__(self, config, num_choices, output_attentions=False):
super(BertForMultipleChoice, self).__init__(config)
self.output_attentions = output_attentions
self.num_choices = num_choices
self.bert = BertModel(config)
self.bert = BertModel(config, output_attentions=output_attentions)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, 1)
self.apply(self.init_bert_weights)
@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, _, pooled_output = outputs
else:
_, pooled_output = outputs
pooled_output = self.dropout(pooled_output)
logits = self.classifier(pooled_output)
reshaped_logits = logits.view(-1, self.num_choices)
@@ -1088,8 +1117,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
loss_fct = CrossEntropyLoss()
loss = loss_fct(reshaped_logits, labels)
return loss
else:
return reshaped_logits
elif self.output_attentions:
return all_attentions, reshaped_logits
return reshaped_logits
class BertForTokenClassification(BertPreTrainedModel):
@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config, num_labels):
def __init__(self, config, num_labels, output_attentions=False):
super(BertForTokenClassification, self).__init__(config)
self.output_attentions = output_attentions
self.num_labels = num_labels
self.bert = BertModel(config)
self.bert = BertModel(config, output_attentions=output_attentions)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.classifier = nn.Linear(config.hidden_size, num_labels)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, _ = outputs
else:
sequence_output, _ = outputs
sequence_output = self.dropout(sequence_output)
logits = self.classifier(sequence_output)
@@ -1161,8 +1196,9 @@ class BertForTokenClassification(BertPreTrainedModel):
else:
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
return loss
else:
return logits
elif self.output_attentions:
return all_attentions, logits
return logits
class BertForQuestionAnswering(BertPreTrainedModel):
@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
```
"""
def __init__(self, config):
def __init__(self, config, output_attentions=False):
super(BertForQuestionAnswering, self).__init__(config)
self.bert = BertModel(config)
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.output_attentions = output_attentions
self.bert = BertModel(config, output_attentions=output_attentions)
self.qa_outputs = nn.Linear(config.hidden_size, 2)
self.apply(self.init_bert_weights)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
if self.output_attentions:
all_attentions, sequence_output, _ = outputs
else:
sequence_output, _ = outputs
logits = self.qa_outputs(sequence_output)
start_logits, end_logits = logits.split(1, dim=-1)
start_logits = start_logits.squeeze(-1)
@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
end_loss = loss_fct(end_logits, end_positions)
total_loss = (start_loss + end_loss) / 2
return total_loss
else:
return start_logits, end_logits
elif self.output_attentions:
return all_attentions, start_logits, end_logits
return start_logits, end_logits