From 03d8527de06405fcc940c39aa4ce93f769774cd7 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Mon, 25 May 2020 22:13:36 +0530 Subject: [PATCH] Longformer for question answering (#4500) * added LongformerForQuestionAnswering * add LongformerForQuestionAnswering * fix import for LongformerForMaskedLM * add LongformerForQuestionAnswering * hardcoded sep_token_id * compute attention_mask if not provided * combine global_attention_mask with attention_mask when provided * update example in docstring * add assert error messages, better attention combine * add test for longformerForQuestionAnswering * typo * cast gloabl_attention_mask to long * make style * Update src/transformers/configuration_longformer.py * Update src/transformers/configuration_longformer.py * fix the code quality * Merge branch 'longformer-for-question-answering' of https://github.com/patil-suraj/transformers into longformer-for-question-answering Co-authored-by: Patrick von Platen --- src/transformers/__init__.py | 7 +- src/transformers/configuration_longformer.py | 3 +- src/transformers/modeling_auto.py | 8 +- src/transformers/modeling_longformer.py | 141 +++++++++++++++++++ tests/test_modeling_longformer.py | 47 +++++++ 5 files changed, 203 insertions(+), 3 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index dc3bfbbeda..3c034d1768 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -335,7 +335,12 @@ if is_torch_available(): REFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, ) - from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerModel, LongformerForMaskedLM + from .modeling_longformer import ( + LongformerModel, + LongformerForMaskedLM, + LongformerForQuestionAnswering, + LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, + ) # Optimization from .optimization import ( diff --git a/src/transformers/configuration_longformer.py b/src/transformers/configuration_longformer.py index c1862b6b8e..dedafac943 100644 --- a/src/transformers/configuration_longformer.py +++ b/src/transformers/configuration_longformer.py @@ -64,6 +64,7 @@ class LongformerConfig(RobertaConfig): pretrained_config_archive_map = LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP model_type = "longformer" - def __init__(self, attention_window: Union[List[int], int] = 512, **kwargs): + def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs): super().__init__(**kwargs) self.attention_window = attention_window + self.sep_token_id = sep_token_id diff --git a/src/transformers/modeling_auto.py b/src/transformers/modeling_auto.py index 9bc0869f76..7d4ec99d09 100644 --- a/src/transformers/modeling_auto.py +++ b/src/transformers/modeling_auto.py @@ -101,7 +101,12 @@ from .modeling_flaubert import ( FlaubertWithLMHeadModel, ) from .modeling_gpt2 import GPT2_PRETRAINED_MODEL_ARCHIVE_MAP, GPT2LMHeadModel, GPT2Model -from .modeling_longformer import LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, LongformerForMaskedLM, LongformerModel +from .modeling_longformer import ( + LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP, + LongformerForMaskedLM, + LongformerForQuestionAnswering, + LongformerModel, +) from .modeling_marian import MarianMTModel from .modeling_openai import OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP, OpenAIGPTLMHeadModel, OpenAIGPTModel from .modeling_reformer import ReformerModel, ReformerModelWithLMHead @@ -260,6 +265,7 @@ MODEL_FOR_QUESTION_ANSWERING_MAPPING = OrderedDict( [ (DistilBertConfig, DistilBertForQuestionAnswering), (AlbertConfig, AlbertForQuestionAnswering), + (LongformerConfig, LongformerForQuestionAnswering), (RobertaConfig, RobertaForQuestionAnswering), (BertConfig, BertForQuestionAnswering), (XLNetConfig, XLNetForQuestionAnsweringSimple), diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index bce1d881c6..8ff2534bcb 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -707,3 +707,144 @@ class LongformerForMaskedLM(BertPreTrainedModel): outputs = (masked_lm_loss,) + outputs return outputs # (masked_lm_loss), prediction_scores, (hidden_states), (attentions) + + +@add_start_docstrings( + """Longformer Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of + the hidden-states output to compute `span start logits` and `span end logits`). """, + LONGFORMER_START_DOCSTRING, +) +class LongformerForQuestionAnswering(BertPreTrainedModel): + config_class = LongformerConfig + pretrained_model_archive_map = LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP + base_model_prefix = "longformer" + + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.longformer = LongformerModel(config) + self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels) + + self.init_weights() + + def _get_question_end_index(self, input_ids): + sep_token_indices = (input_ids == self.config.sep_token_id).nonzero() + + assert sep_token_indices.size(1) == 2, "input_ids should have two dimensions" + assert sep_token_indices.size(0) == 3 * input_ids.size( + 0 + ), "There should be exactly three separator tokens in every sample for questions answering" + + return sep_token_indices.view(input_ids.size(0), 3, 2)[:, 0, 1] + + def _compute_global_attention_mask(self, input_ids): + question_end_index = self._get_question_end_index(input_ids) + question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 + # bool attention mask with True in locations of global attention + attention_mask = torch.arange(input_ids.size(1), device=input_ids.device) + attention_mask = attention_mask.expand_as(input_ids) < question_end_index + + attention_mask = attention_mask.int() + 1 # from True, False to 2, 1 + return attention_mask.long() + + @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING) + def forward( + self, + input_ids, + attention_mask=None, + token_type_ids=None, + position_ids=None, + inputs_embeds=None, + start_positions=None, + end_positions=None, + ): + r""" + start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the start of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + end_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`): + Labels for position (index) of the end of the labelled span for computing the token classification loss. + Positions are clamped to the length of the sequence (`sequence_length`). + Position outside of the sequence are not taken into account for computing the loss. + Returns: + :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs: + loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided): + Total span extraction loss is the sum of a Cross-Entropy for the start and end positions. + start_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-start scores (before SoftMax). + end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`): + Span-end scores (before SoftMax). + hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``): + Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) + of shape :obj:`(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the initial embedding outputs. + attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``): + Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape + :obj:`(batch_size, num_heads, sequence_length, sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + Examples:: + from transformers import LongformerTokenizer, LongformerForQuestionAnswering + import torch + + tokenizer = LongformerTokenizer.from_pretrained(longformer-base-4096') + model = LongformerForQuestionAnswering.from_pretrained(longformer-base-4096') + + question, text = "Who was Jim Henson?", "Jim Henson was a nice puppet" + encoding = tokenizer.encode_plus(question, text) + input_ids = encoding["input_ids"] + + # default is local attention everywhere + # the forward method will automatically set global attention on question tokens + attention_mask = encoding["attention_mask"] + + start_scores, end_scores = model(torch.tensor([input_ids]), attention_mask=attention_mask) + all_tokens = tokenizer.convert_ids_to_tokens(input_ids) + answer = ' '.join(all_tokens[torch.argmax(start_scores) : torch.argmax(end_scores)+1]) + """ + + # set global attention on question tokens + global_attention_mask = self._compute_global_attention_mask(input_ids) + if attention_mask is None: + attention_mask = global_attention_mask + else: + # combine global_attention_mask with attention_mask + # global attention on question tokens, no attention on padding tokens + attention_mask = global_attention_mask * attention_mask + + outputs = self.longformer( + input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + inputs_embeds=inputs_embeds, + ) + + sequence_output = outputs[0] + + logits = self.qa_outputs(sequence_output) + start_logits, end_logits = logits.split(1, dim=-1) + start_logits = start_logits.squeeze(-1) + end_logits = end_logits.squeeze(-1) + + outputs = (start_logits, end_logits,) + outputs[2:] + if start_positions is not None and end_positions is not None: + # If we are on multi-GPU, split add a dimension + if len(start_positions.size()) > 1: + start_positions = start_positions.squeeze(-1) + if len(end_positions.size()) > 1: + end_positions = end_positions.squeeze(-1) + # sometimes the start/end positions are outside our model inputs, we ignore these terms + ignored_index = start_logits.size(1) + start_positions.clamp_(0, ignored_index) + end_positions.clamp_(0, ignored_index) + + loss_fct = CrossEntropyLoss(ignore_index=ignored_index) + start_loss = loss_fct(start_logits, start_positions) + end_loss = loss_fct(end_logits, end_positions) + total_loss = (start_loss + end_loss) / 2 + outputs = (total_loss,) + outputs + + return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions) diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 0d7ba840db..d850a9d30f 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -29,6 +29,7 @@ if is_torch_available(): LongformerConfig, LongformerModel, LongformerForMaskedLM, + LongformerForQuestionAnswering, ) @@ -171,6 +172,28 @@ class LongformerModelTester(object): ) self.check_loss_output(result) + def create_and_check_longformer_for_question_answering( + self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + ): + model = LongformerForQuestionAnswering(config=config) + model.to(torch_device) + model.eval() + loss, start_logits, end_logits = model( + input_ids, + attention_mask=input_mask, + token_type_ids=token_type_ids, + start_positions=sequence_labels, + end_positions=sequence_labels, + ) + result = { + "loss": loss, + "start_logits": start_logits, + "end_logits": end_logits, + } + self.parent.assertListEqual(list(result["start_logits"].size()), [self.batch_size, self.seq_length]) + self.parent.assertListEqual(list(result["end_logits"].size()), [self.batch_size, self.seq_length]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -185,6 +208,26 @@ class LongformerModelTester(object): inputs_dict = {"input_ids": input_ids, "token_type_ids": token_type_ids, "attention_mask": input_mask} return config, inputs_dict + def prepare_config_and_inputs_for_question_answering(self): + config_and_inputs = self.prepare_config_and_inputs() + ( + config, + input_ids, + token_type_ids, + input_mask, + sequence_labels, + token_labels, + choice_labels, + ) = config_and_inputs + + # Replace sep_token_id by some random id + input_ids[input_ids == config.sep_token_id] = torch.randint(0, config.vocab_size, (1,)).item() + # Make sure there are exactly three sep_token_id + input_ids[:, -3:] = config.sep_token_id + input_mask = torch.ones_like(input_ids) + + return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels + @require_torch class LongformerModelTest(ModelTesterMixin, unittest.TestCase): @@ -209,6 +252,10 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs) + def test_longformer_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering() + self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs) + class LongformerModelIntegrationTest(unittest.TestCase): @slow