add attention to all bert models and add test
This commit is contained in:
@@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, output_attentions=False):
|
||||
super(BertForPreTraining, self).__init__(config)
|
||||
self.bert = BertModel(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None):
|
||||
sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, pooled_output = outputs
|
||||
else:
|
||||
sequence_output, pooled_output = outputs
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
|
||||
if masked_lm_labels is not None and next_sentence_label is not None:
|
||||
@@ -830,8 +835,9 @@ class BertForPreTraining(BertPreTrainedModel):
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
total_loss = masked_lm_loss + next_sentence_loss
|
||||
return total_loss
|
||||
else:
|
||||
return prediction_scores, seq_relationship_score
|
||||
elif self.output_attentions:
|
||||
return all_attentions, prediction_scores, seq_relationship_score
|
||||
return prediction_scores, seq_relationship_score
|
||||
|
||||
|
||||
class BertForMaskedLM(BertPreTrainedModel):
|
||||
@@ -876,23 +882,29 @@ class BertForMaskedLM(BertPreTrainedModel):
|
||||
masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, output_attentions=False):
|
||||
super(BertForMaskedLM, self).__init__(config)
|
||||
self.bert = BertModel(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None):
|
||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
sequence_output, _ = outputs
|
||||
prediction_scores = self.cls(sequence_output)
|
||||
|
||||
if masked_lm_labels is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1))
|
||||
return masked_lm_loss
|
||||
else:
|
||||
return prediction_scores
|
||||
elif self.output_attentions:
|
||||
return all_attentions, prediction_scores
|
||||
return prediction_scores
|
||||
|
||||
|
||||
class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
@@ -938,23 +950,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel):
|
||||
seq_relationship_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, output_attentions=False):
|
||||
super(BertForNextSentencePrediction, self).__init__(config)
|
||||
self.bert = BertModel(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.cls = BertOnlyNSPHead(config)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None):
|
||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask,
|
||||
output_all_encoded_layers=False)
|
||||
seq_relationship_score = self.cls( pooled_output)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
_, pooled_output = outputs
|
||||
seq_relationship_score = self.cls(pooled_output)
|
||||
|
||||
if next_sentence_label is not None:
|
||||
loss_fct = CrossEntropyLoss(ignore_index=-1)
|
||||
next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
|
||||
return next_sentence_loss
|
||||
else:
|
||||
return seq_relationship_score
|
||||
elif self.output_attentions:
|
||||
return all_attentions, seq_relationship_score
|
||||
return seq_relationship_score
|
||||
|
||||
|
||||
class BertForSequenceClassification(BertPreTrainedModel):
|
||||
@@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels):
|
||||
def __init__(self, config, num_labels, output_attentions=False):
|
||||
super(BertForSequenceClassification, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
_, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
_, pooled_output = outputs
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
|
||||
@@ -1019,8 +1042,9 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss
|
||||
else:
|
||||
return logits
|
||||
elif self.output_attentions:
|
||||
return all_attentions, logits
|
||||
return logits
|
||||
|
||||
|
||||
class BertForMultipleChoice(BertPreTrainedModel):
|
||||
@@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_choices):
|
||||
def __init__(self, config, num_choices, output_attentions=False):
|
||||
super(BertForMultipleChoice, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.num_choices = num_choices
|
||||
self.bert = BertModel(config)
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, 1)
|
||||
self.apply(self.init_bert_weights)
|
||||
@@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
_, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||
outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, _, pooled_output = outputs
|
||||
else:
|
||||
_, pooled_output = outputs
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
logits = self.classifier(pooled_output)
|
||||
reshaped_logits = logits.view(-1, self.num_choices)
|
||||
@@ -1088,8 +1117,9 @@ class BertForMultipleChoice(BertPreTrainedModel):
|
||||
loss_fct = CrossEntropyLoss()
|
||||
loss = loss_fct(reshaped_logits, labels)
|
||||
return loss
|
||||
else:
|
||||
return reshaped_logits
|
||||
elif self.output_attentions:
|
||||
return all_attentions, reshaped_logits
|
||||
return reshaped_logits
|
||||
|
||||
|
||||
class BertForTokenClassification(BertPreTrainedModel):
|
||||
@@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config, num_labels):
|
||||
def __init__(self, config, num_labels, output_attentions=False):
|
||||
super(BertForTokenClassification, self).__init__(config)
|
||||
self.output_attentions = output_attentions
|
||||
self.num_labels = num_labels
|
||||
self.bert = BertModel(config)
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.classifier = nn.Linear(config.hidden_size, num_labels)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
|
||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
sequence_output, _ = outputs
|
||||
sequence_output = self.dropout(sequence_output)
|
||||
logits = self.classifier(sequence_output)
|
||||
|
||||
@@ -1161,8 +1196,9 @@ class BertForTokenClassification(BertPreTrainedModel):
|
||||
else:
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
return loss
|
||||
else:
|
||||
return logits
|
||||
elif self.output_attentions:
|
||||
return all_attentions, logits
|
||||
return logits
|
||||
|
||||
|
||||
class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
@@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
start_logits, end_logits = model(input_ids, token_type_ids, input_mask)
|
||||
```
|
||||
"""
|
||||
def __init__(self, config):
|
||||
def __init__(self, config, output_attentions=False):
|
||||
super(BertForQuestionAnswering, self).__init__(config)
|
||||
self.bert = BertModel(config)
|
||||
# TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version
|
||||
# self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.output_attentions = output_attentions
|
||||
self.bert = BertModel(config, output_attentions=output_attentions)
|
||||
self.qa_outputs = nn.Linear(config.hidden_size, 2)
|
||||
self.apply(self.init_bert_weights)
|
||||
|
||||
def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None):
|
||||
sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
|
||||
if self.output_attentions:
|
||||
all_attentions, sequence_output, _ = outputs
|
||||
else:
|
||||
sequence_output, _ = outputs
|
||||
logits = self.qa_outputs(sequence_output)
|
||||
start_logits, end_logits = logits.split(1, dim=-1)
|
||||
start_logits = start_logits.squeeze(-1)
|
||||
@@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel):
|
||||
end_loss = loss_fct(end_logits, end_positions)
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
return total_loss
|
||||
else:
|
||||
return start_logits, end_logits
|
||||
elif self.output_attentions:
|
||||
return all_attentions, start_logits, end_logits
|
||||
return start_logits, end_logits
|
||||
|
||||
Reference in New Issue
Block a user