[Longformer] Multiple choice for longformer (#4645)

* add multiple choice for longformer

* add models to docs

* adapt docstring

* add test to longformer

* add longformer for mc in init and modeling auto

* fix tests
This commit is contained in:
Patrick von Platen
2020-05-29 13:46:08 +02:00
committed by GitHub
parent 91487cbb8e
commit 9c17256447
12 changed files with 227 additions and 63 deletions

View File

@@ -32,6 +32,7 @@ if is_torch_available():
LongformerForSequenceClassification,
LongformerForTokenClassification,
LongformerForQuestionAnswering,
LongformerForMultipleChoice,
)
@@ -228,6 +229,29 @@ class LongformerModelTester(object):
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.seq_length, self.num_labels])
self.check_loss_output(result)
def create_and_check_longformer_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
model = LongformerForMultipleChoice(config=config)
model.to(torch_device)
model.eval()
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
result = {
"loss": loss,
"logits": logits,
}
self.parent.assertListEqual(list(result["logits"].size()), [self.batch_size, self.num_choices])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -298,6 +322,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
@slow