Longformer for question answering (#4500)
* added LongformerForQuestionAnswering * add LongformerForQuestionAnswering * fix import for LongformerForMaskedLM * add LongformerForQuestionAnswering * hardcoded sep_token_id * compute attention_mask if not provided * combine global_attention_mask with attention_mask when provided * update example in docstring * add assert error messages, better attention combine * add test for longformerForQuestionAnswering * typo * cast gloabl_attention_mask to long * make style * Update src/transformers/configuration_longformer.py * Update src/transformers/configuration_longformer.py * fix the code quality * Merge branch 'longformer-for-question-answering' of https://github.com/patil-suraj/transformers into longformer-for-question-answering Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -29,6 +29,7 @@ if is_torch_available():
|
||||
LongformerConfig,
|
||||
LongformerModel,
|
||||
LongformerForMaskedLM,
|
||||
LongformerForQuestionAnswering,
|
||||
)
|
||||
|
||||
|
||||
@@ -171,6 +172,28 @@ class LongformerModelTester(object):
|
||||
)
|
||||
self.check_loss_output(result)
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerForQuestionAnswering(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
)
|
||||
result = {
|
||||
"loss": loss,
|
||||
"start_logits": start_logits,
|
||||
"end_logits": end_logits,
|
||||
}
|
||||
self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length])
|
||||
self.check_loss_output(result)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -185,6 +208,26 @@ class LongformerModelTester(object):
|
||||
inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask}
|
||||
return config, inputs_dict
|
||||
|
||||
def prepare_config_and_inputs_for_question_answering(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
config,
|
||||
input_ids,
|
||||
token_type_ids,
|
||||
input_mask,
|
||||
sequence_labels,
|
||||
token_labels,
|
||||
choice_labels,
|
||||
) = config_and_inputs
|
||||
|
||||
# Replace sep_token_id by some random id
|
||||
input_ids[input_ids == config.sep_token_id] = torch.randint(0, config.vocab_size, (1,)).item()
|
||||
# Make sure there are exactly three sep_token_id
|
||||
input_ids[:, -3:] = config.sep_token_id
|
||||
input_mask = torch.ones_like(input_ids)
|
||||
|
||||
return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
|
||||
|
||||
@require_torch
|
||||
class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
@@ -209,6 +252,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
|
||||
|
||||
class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user