From 5e1207b8ad00fd649c0f35b9697cd67ce9897505 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Fri, 14 Jun 2019 16:28:25 +0200 Subject: [PATCH] add attention to all bert models and add test --- pytorch_pretrained_bert/modeling.py | 116 +++++++++++++++++++--------- tests/modeling_test.py | 71 ++++++++++++++--- 2 files changed, 140 insertions(+), 47 deletions(-) diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 27682eb369..bfcbcc9edf 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -813,15 +813,20 @@ class BertForPreTraining(BertPreTrainedModel): masked_lm_logits_scores, seq_relationship_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertForPreTraining, self).__init__(config) - self.bert = BertModel(config) + self.output_attentions = output_attentions + self.bert = BertModel(config, output_attentions=output_attentions) self.cls = BertPreTrainingHeads(config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None, next_sentence_label=None): - sequence_output, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, sequence_output, pooled_output = outputs + else: + sequence_output, pooled_output = outputs prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output) if masked_lm_labels is not None and next_sentence_label is not None: @@ -830,8 +835,9 @@ class BertForPreTraining(BertPreTrainedModel): next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) total_loss = masked_lm_loss + next_sentence_loss return total_loss - else: - return prediction_scores, seq_relationship_score + elif self.output_attentions: + return all_attentions, prediction_scores, seq_relationship_score + return prediction_scores, seq_relationship_score class BertForMaskedLM(BertPreTrainedModel): @@ -876,23 +882,29 @@ class BertForMaskedLM(BertPreTrainedModel): masked_lm_logits_scores = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertForMaskedLM, self).__init__(config) - self.bert = BertModel(config) + self.output_attentions = output_attentions + self.bert = BertModel(config, output_attentions=output_attentions) self.cls = BertOnlyMLMHead(config, self.bert.embeddings.word_embeddings.weight) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, masked_lm_labels=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, sequence_output, _ = outputs + else: + sequence_output, _ = outputs prediction_scores = self.cls(sequence_output) if masked_lm_labels is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), masked_lm_labels.view(-1)) return masked_lm_loss - else: - return prediction_scores + elif self.output_attentions: + return all_attentions, prediction_scores + return prediction_scores class BertForNextSentencePrediction(BertPreTrainedModel): @@ -938,23 +950,29 @@ class BertForNextSentencePrediction(BertPreTrainedModel): seq_relationship_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertForNextSentencePrediction, self).__init__(config) - self.bert = BertModel(config) + self.output_attentions = output_attentions + self.bert = BertModel(config, output_attentions=output_attentions) self.cls = BertOnlyNSPHead(config) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, next_sentence_label=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) - seq_relationship_score = self.cls( pooled_output) + if self.output_attentions: + all_attentions, _, pooled_output = outputs + else: + _, pooled_output = outputs + seq_relationship_score = self.cls(pooled_output) if next_sentence_label is not None: loss_fct = CrossEntropyLoss(ignore_index=-1) next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1)) return next_sentence_loss - else: - return seq_relationship_score + elif self.output_attentions: + return all_attentions, seq_relationship_score + return seq_relationship_score class BertForSequenceClassification(BertPreTrainedModel): @@ -1002,16 +1020,21 @@ class BertForSequenceClassification(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels): + def __init__(self, config, num_labels, output_attentions=False): super(BertForSequenceClassification, self).__init__(config) + self.output_attentions = output_attentions self.num_labels = num_labels - self.bert = BertModel(config) + self.bert = BertModel(config, output_attentions=output_attentions) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - _, pooled_output = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, _, pooled_output = outputs + else: + _, pooled_output = outputs pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) @@ -1019,8 +1042,9 @@ class BertForSequenceClassification(BertPreTrainedModel): loss_fct = CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return loss - else: - return logits + elif self.output_attentions: + return all_attentions, logits + return logits class BertForMultipleChoice(BertPreTrainedModel): @@ -1067,10 +1091,11 @@ class BertForMultipleChoice(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_choices): + def __init__(self, config, num_choices, output_attentions=False): super(BertForMultipleChoice, self).__init__(config) + self.output_attentions = output_attentions self.num_choices = num_choices - self.bert = BertModel(config) + self.bert = BertModel(config, output_attentions=output_attentions) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, 1) self.apply(self.init_bert_weights) @@ -1079,7 +1104,11 @@ class BertForMultipleChoice(BertPreTrainedModel): flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None - _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + outputs = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, _, pooled_output = outputs + else: + _, pooled_output = outputs pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) reshaped_logits = logits.view(-1, self.num_choices) @@ -1088,8 +1117,9 @@ class BertForMultipleChoice(BertPreTrainedModel): loss_fct = CrossEntropyLoss() loss = loss_fct(reshaped_logits, labels) return loss - else: - return reshaped_logits + elif self.output_attentions: + return all_attentions, reshaped_logits + return reshaped_logits class BertForTokenClassification(BertPreTrainedModel): @@ -1137,16 +1167,21 @@ class BertForTokenClassification(BertPreTrainedModel): logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config, num_labels): + def __init__(self, config, num_labels, output_attentions=False): super(BertForTokenClassification, self).__init__(config) + self.output_attentions = output_attentions self.num_labels = num_labels - self.bert = BertModel(config) + self.bert = BertModel(config, output_attentions=output_attentions) self.dropout = nn.Dropout(config.hidden_dropout_prob) self.classifier = nn.Linear(config.hidden_size, num_labels) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, sequence_output, _ = outputs + else: + sequence_output, _ = outputs sequence_output = self.dropout(sequence_output) logits = self.classifier(sequence_output) @@ -1161,8 +1196,9 @@ class BertForTokenClassification(BertPreTrainedModel): else: loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) return loss - else: - return logits + elif self.output_attentions: + return all_attentions, logits + return logits class BertForQuestionAnswering(BertPreTrainedModel): @@ -1212,16 +1248,19 @@ class BertForQuestionAnswering(BertPreTrainedModel): start_logits, end_logits = model(input_ids, token_type_ids, input_mask) ``` """ - def __init__(self, config): + def __init__(self, config, output_attentions=False): super(BertForQuestionAnswering, self).__init__(config) - self.bert = BertModel(config) - # TODO check with Google if it's normal there is no dropout on the token classifier of SQuAD in the TF version - # self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.output_attentions = output_attentions + self.bert = BertModel(config, output_attentions=output_attentions) self.qa_outputs = nn.Linear(config.hidden_size, 2) self.apply(self.init_bert_weights) def forward(self, input_ids, token_type_ids=None, attention_mask=None, start_positions=None, end_positions=None): - sequence_output, _ = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + outputs = self.bert(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False) + if self.output_attentions: + all_attentions, sequence_output, _ = outputs + else: + sequence_output, _ = outputs logits = self.qa_outputs(sequence_output) start_logits, end_logits = logits.split(1, dim=-1) start_logits = start_logits.squeeze(-1) @@ -1243,5 +1282,6 @@ class BertForQuestionAnswering(BertPreTrainedModel): end_loss = loss_fct(end_logits, end_positions) total_loss = (start_loss + end_loss) / 2 return total_loss - else: - return start_logits, end_logits + elif self.output_attentions: + return all_attentions, start_logits, end_logits + return start_logits, end_logits diff --git a/tests/modeling_test.py b/tests/modeling_test.py index 5cde383fdf..79993ed840 100644 --- a/tests/modeling_test.py +++ b/tests/modeling_test.py @@ -28,7 +28,7 @@ import torch from pytorch_pretrained_bert import (BertConfig, BertModel, BertForMaskedLM, BertForNextSentencePrediction, BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, - BertForTokenClassification) + BertForTokenClassification, BertForMultipleChoice) from pytorch_pretrained_bert.modeling import PRETRAINED_MODEL_ARCHIVE_MAP @@ -56,6 +56,7 @@ class BertModelTest(unittest.TestCase): type_sequence_label_size=2, initializer_range=0.02, num_labels=3, + num_choices=4, scope=None): self.parent = parent self.batch_size = batch_size @@ -77,6 +78,7 @@ class BertModelTest(unittest.TestCase): self.type_sequence_label_size = type_sequence_label_size self.initializer_range = initializer_range self.num_labels = num_labels + self.num_choices = num_choices self.scope = scope def prepare_config_and_inputs(self): @@ -92,9 +94,11 @@ class BertModelTest(unittest.TestCase): sequence_labels = None token_labels = None + choice_labels = None if self.use_labels: sequence_labels = BertModelTest.ids_tensor([self.batch_size], self.type_sequence_label_size) token_labels = BertModelTest.ids_tensor([self.batch_size, self.seq_length], self.num_labels) + choice_labels = BertModelTest.ids_tensor([self.batch_size], self.num_choices) config = BertConfig( vocab_size_or_config_json_file=self.vocab_size, @@ -109,14 +113,14 @@ class BertModelTest(unittest.TestCase): type_vocab_size=self.type_vocab_size, initializer_range=self.initializer_range) - return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels def check_loss_output(self, result): self.parent.assertListEqual( list(result["loss"].size()), []) - def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertModel(config=config) model.eval() all_encoder_layers, pooled_output = model(input_ids, token_type_ids, input_mask) @@ -137,7 +141,7 @@ class BertModelTest(unittest.TestCase): self.parent.assertListEqual(list(result["pooled_output"].size()), [self.batch_size, self.hidden_size]) - def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_masked_lm(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForMaskedLM(config=config) model.eval() loss = model(input_ids, token_type_ids, input_mask, token_labels) @@ -153,7 +157,7 @@ class BertModelTest(unittest.TestCase): list(result["prediction_scores"].size()), [self.batch_size, self.seq_length, self.vocab_size]) - def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_next_sequence_prediction(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForNextSentencePrediction(config=config) model.eval() loss = model(input_ids, token_type_ids, input_mask, sequence_labels) @@ -170,7 +174,7 @@ class BertModelTest(unittest.TestCase): [self.batch_size, 2]) - def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_pretraining(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForPreTraining(config=config) model.eval() loss = model(input_ids, token_type_ids, input_mask, token_labels, sequence_labels) @@ -191,7 +195,7 @@ class BertModelTest(unittest.TestCase): [self.batch_size, 2]) - def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForQuestionAnswering(config=config) model.eval() loss = model(input_ids, token_type_ids, input_mask, sequence_labels, sequence_labels) @@ -212,7 +216,7 @@ class BertModelTest(unittest.TestCase): [self.batch_size, self.seq_length]) - def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForSequenceClassification(config=config, num_labels=self.num_labels) model.eval() loss = model(input_ids, token_type_ids, input_mask, sequence_labels) @@ -229,7 +233,7 @@ class BertModelTest(unittest.TestCase): [self.batch_size, self.num_labels]) - def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels): + def create_bert_for_token_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): model = BertForTokenClassification(config=config, num_labels=self.num_labels) model.eval() loss = model(input_ids, token_type_ids, input_mask, token_labels) @@ -246,6 +250,49 @@ class BertModelTest(unittest.TestCase): [self.batch_size, self.seq_length, self.num_labels]) + def create_bert_for_multiple_choice(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = BertForMultipleChoice(config=config, num_choices=self.num_choices) + model.eval() + multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + loss = model(multiple_choice_inputs_ids, + multiple_choice_token_type_ids, + multiple_choice_input_mask, + choice_labels) + logits = model(multiple_choice_inputs_ids, + multiple_choice_token_type_ids, + multiple_choice_input_mask) + outputs = { + "loss": loss, + "logits": logits, + } + return outputs + + def check_bert_for_multiple_choice(self, result): + self.parent.assertListEqual( + list(result["logits"].size()), + [self.batch_size, self.num_choices]) + + + def create_and_check_bert_for_attentions(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + for model_class in (BertModel, BertForMaskedLM, BertForNextSentencePrediction, + BertForPreTraining, BertForQuestionAnswering, BertForSequenceClassification, + BertForTokenClassification): + if model_class in [BertForSequenceClassification, + BertForTokenClassification]: + model = model_class(config=config, num_labels=self.num_labels, output_attentions=True) + else: + model = model_class(config=config, output_attentions=True) + model.eval() + output = model(input_ids, token_type_ids, input_mask) + attentions = output[0] + self.parent.assertEqual(len(attentions), self.num_hidden_layers) + self.parent.assertListEqual( + list(attentions[0].size()), + [self.batch_size, self.num_attention_heads, self.seq_length, self.seq_length]) + + def test_default(self): self.run_tester(BertModelTest.BertModelTester(self)) @@ -300,6 +347,12 @@ class BertModelTest(unittest.TestCase): tester.check_bert_for_token_classification_output(output_result) tester.check_loss_output(output_result) + output_result = tester.create_bert_for_multiple_choice(*config_and_inputs) + tester.check_bert_for_multiple_choice(output_result) + tester.check_loss_output(output_result) + + tester.create_and_check_bert_for_attentions(*config_and_inputs) + @classmethod def ids_tensor(cls, shape, vocab_size, rng=None, name=None): """Creates a random int32 tensor of the shape within the vocab size."""