Longformer for question answering (#4500)

* added LongformerForQuestionAnswering

* add LongformerForQuestionAnswering

* fix import for LongformerForMaskedLM

* add LongformerForQuestionAnswering

* hardcoded sep_token_id

* compute attention_mask if not provided

* combine global_attention_mask with attention_mask when provided

* update example in  docstring

* add assert error messages, better attention combine

* add test for longformerForQuestionAnswering

* typo

* cast gloabl_attention_mask to long

* make style

* Update src/transformers/configuration_longformer.py

* Update src/transformers/configuration_longformer.py

* fix the code quality

* Merge branch 'longformer-for-question-answering' of https://github.com/patil-suraj/transformers into longformer-for-question-answering

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Suraj Patil
2020-05-25 22:13:36 +05:30
committed by GitHub
parent a34a9896ac
commit 03d8527de0
5 changed files with 203 additions and 3 deletions

View File

@@ -29,6 +29,7 @@ if is_torch_available():
LongformerConfig,
LongformerModel,
LongformerForMaskedLM,
LongformerForQuestionAnswering,
)
@@ -171,6 +172,28 @@ class LongformerModelTester(object):
)
self.check_loss_output(result)
def create_and_check_longformer_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = LongformerForQuestionAnswering(config=config)
model.to(torch_device)
model.eval()
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
)
result = {
"loss": loss,
"start_logits": start_logits,
"end_logits": end_logits,
}
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
self.check_loss_output(result)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -185,6 +208,26 @@ class LongformerModelTester(object):
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
return config, inputs_dict
def prepare_config_and_inputs_for_question_answering(self):
config_and_inputs = self.prepare_config_and_inputs()
(
config,
input_ids,
token_type_ids,
input_mask,
sequence_labels,
token_labels,
choice_labels,
) = config_and_inputs
# Replace sep_token_id by some random id
input_ids[input_ids == config.sep_token_id] = torch.randint(0, config.vocab_size, (1,)).item()
# Make sure there are exactly three sep_token_id
input_ids[:, -3:] = config.sep_token_id
input_mask = torch.ones_like(input_ids)
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
@require_torch
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
@@ -209,6 +252,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
class LongformerModelIntegrationTest(unittest.TestCase):
@slow