From fc5a38ac92c74577523665d8f8602c45ae8dcd8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?= Date: Thu, 6 Dec 2018 18:42:23 +0100 Subject: [PATCH] Adding the BertForMultipleChoiceClass. --- pytorch_pretrained_bert/__init__.py | 4 +- pytorch_pretrained_bert/modeling.py | 69 +++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/pytorch_pretrained_bert/__init__.py b/pytorch_pretrained_bert/__init__.py index fc9b15a12d..e1ecabf31d 100644 --- a/pytorch_pretrained_bert/__init__.py +++ b/pytorch_pretrained_bert/__init__.py @@ -1,7 +1,7 @@ from .tokenization import BertTokenizer, BasicTokenizer, WordpieceTokenizer from .modeling import (BertConfig, BertModel, BertForPreTraining, BertForMaskedLM, BertForNextSentencePrediction, - BertForSequenceClassification, BertForTokenClassification, - BertForQuestionAnswering) + BertForSequenceClassification, BertForMultipleChoice, + BertForTokenClassification, BertForQuestionAnswering) from .optimization import BertAdam from .file_utils import PYTORCH_PRETRAINED_BERT_CACHE diff --git a/pytorch_pretrained_bert/modeling.py b/pytorch_pretrained_bert/modeling.py index 3af5854072..e23e2c1a1c 100644 --- a/pytorch_pretrained_bert/modeling.py +++ b/pytorch_pretrained_bert/modeling.py @@ -877,6 +877,75 @@ class BertForSequenceClassification(PreTrainedBertModel): return logits +class BertForMultipleChoice(PreTrainedBertModel): + """BERT model for multiple choice tasks. + This module is composed of the BERT model with a linear layer on top of + the pooled output. + + Params: + `config`: a BertConfig class instance with the configuration to build a new model. + `num_choices`: the number of classes for the classifier. Default = 2. + + Inputs: + `input_ids`: a torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the word token indices in the vocabulary(see the tokens preprocessing logic in the scripts + `extract_features.py`, `run_classifier.py` and `run_squad.py`) + `token_type_ids`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] + with the token types indices selected in [0, 1]. Type 0 corresponds to a `sentence A` + and type 1 corresponds to a `sentence B` token (see BERT paper for more details). + `attention_mask`: an optional torch.LongTensor of shape [batch_size, num_choices, sequence_length] with indices + selected in [0, 1]. It's a mask to be used if the input sequence length is smaller than the max + input sequence length in the current batch. It's the mask that we typically use for attention when + a batch has varying length sentences. + `labels`: labels for the classification output: torch.LongTensor of shape [batch_size] + with indices selected in [0, ..., num_choices]. + + Outputs: + if `labels` is not `None`: + Outputs the CrossEntropy classification loss of the output with the labels. + if `labels` is `None`: + Outputs the classification logits of shape [batch_size, num_labels]. + + Example usage: + ```python + # Already been converted into WordPiece token ids + input_ids = torch.LongTensor([[[31, 51, 99], [15, 5, 0]], [[12, 16, 42], [14, 28, 57]]]) + input_mask = torch.LongTensor([[[1, 1, 1], [1, 1, 0]],[[1,1,0], [1, 0, 0]]]) + token_type_ids = torch.LongTensor([[[0, 0, 1], [0, 1, 0]],[[0, 1, 1], [0, 0, 1]]]) + config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768, + num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072) + + num_choices = 2 + + model = BertForMultipleChoice(config, num_choices) + logits = model(input_ids, token_type_ids, input_mask) + ``` + """ + def __init__(self, config, num_choices=2): + super(BertForMultipleChoice, self).__init__(config) + self.num_choices = num_choices + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.classifier = nn.Linear(config.hidden_size, 1) + self.apply(self.init_bert_weights) + + def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None): + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) + flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) + flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) + _, pooled_output = self.bert(flat_input_ids, flat_token_type_ids, flat_attention_mask, output_all_encoded_layers=False) + pooled_output = self.dropout(pooled_output) + logits = self.classifier(pooled_output) + reshaped_logits = logits.view(-1, self.num_choices) + + if labels is not None: + loss_fct = CrossEntropyLoss() + loss = loss_fct(reshaped_logits, labels) + return loss + else: + return reshaped_logits + + class BertForTokenClassification(PreTrainedBertModel): """BERT model for token-level classification. This module is composed of the BERT model with a linear layer on top of