Add DistilBertForMultipleChoice (#5032)

* Add `DistilBertForMultipleChoice`
This commit is contained in:
Sylvain Gugger
2020-06-15 18:31:41 -04:00
committed by GitHub
parent 36434220fc
commit f9f8a5312e
6 changed files with 144 additions and 1 deletions

View File

@@ -28,6 +28,7 @@ if is_torch_available():
DistilBertConfig,
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForTokenClassification,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
@@ -41,6 +42,7 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
(
DistilBertModel,
DistilBertForMaskedLM,
DistilBertForMultipleChoice,
DistilBertForQuestionAnswering,
DistilBertForSequenceClassification,
DistilBertForTokenClassification,
@@ -218,6 +220,25 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
)
self.check_loss_output(result)
def create_and_check_distilbert_for_multiple_choice(
self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = DistilBertForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs
@@ -251,6 +272,10 @@ class DistilBertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_distilbert_for_multiple_choice(*config_and_inputs)
# @slow
# def test_model_from_pretrained(self):
# for model_name in DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: