From 56ee2560bedfc471befe438ec253a1804b6eeb48 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 29 May 2020 17:58:42 +0200 Subject: [PATCH] [Longformer] Better handling of global attention mask vs local attention mask (#4672) * better api * improve automatic setting of global attention mask * fix longformer bug * fix global attention mask in test * fix global attn mask flatten * fix slow tests * update docstring * update docs and make more robust * improve attention mask --- docs/source/model_doc/longformer.rst | 2 +- src/transformers/modeling_longformer.py | 151 +++++++++++++++++------- tests/test_modeling_longformer.py | 7 +- 3 files changed, 114 insertions(+), 46 deletions(-) diff --git a/docs/source/model_doc/longformer.rst b/docs/source/model_doc/longformer.rst index 07d0898ccf..b25bab001c 100644 --- a/docs/source/model_doc/longformer.rst +++ b/docs/source/model_doc/longformer.rst @@ -21,7 +21,7 @@ A selecetd few tokens attend "globally" to all other tokens, as it is convention Note that "locally" and "globally" attending tokens are projected by different query, key and value matrices. Also note that every "locally" attending token not only attends to tokens within its window :math:`w`, but also to all "globally" attending tokens so that global attention is *symmetric*. -The user can define which tokens are masked, which tokens attend "locally" and which tokens attend "globally" by setting the `config.attention_mask` `torch.Tensor` appropriately. In contrast to other models `Longformer` accepts the following values in `config.attention_mask`: `0` - the token is masked and not attended at all (as is done in other models), `1` - the token attends "locally", `2` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method. +The user can define which tokens attend "locally" and which tokens attend "globally" by setting the tensor `global_attention_mask` at run-time appropriately. `Longformer` employs the following logic for `global_attention_mask`: `0` - the token attends "locally", `1` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method. Using Longformer self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times w)`, with :math:`n_s` being the sequence length and :math:`w` being the average window size. It is assumed that the number of "globally" attending tokens is insignificant as compared to the number of "locally" attending tokens. diff --git a/src/transformers/modeling_longformer.py b/src/transformers/modeling_longformer.py index 5baf056a10..2a8e862c2e 100644 --- a/src/transformers/modeling_longformer.py +++ b/src/transformers/modeling_longformer.py @@ -39,6 +39,44 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = { } +def _get_question_end_index(input_ids, sep_token_id): + """ + Computes the index of the first occurance of `sep_token_id`. + """ + + sep_token_indices = (input_ids == sep_token_id).nonzero() + batch_size = input_ids.shape[0] + + assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" + assert ( + sep_token_indices.shape[0] == 3 * batch_size + ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error." + + return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] + + +def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True): + """ + Computes global attention mask by putting attention on all tokens + before `sep_token_id` if `before_sep_token is True` else after + `sep_token_id`. + """ + + question_end_index = _get_question_end_index(input_ids, sep_token_id) + question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 + # bool attention mask with True in locations of global attention + attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) + if before_sep_token is True: + attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8) + else: + # last token is separation token and should not be counted and in the middle are two separation tokens + attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * ( + attention_mask.expand_as(input_ids) < input_ids.shape[-1] + ).to(torch.uint8) + + return attention_mask + + class LongformerSelfAttention(nn.Module): def __init__(self, config, layer_id): super().__init__() @@ -420,17 +458,22 @@ LONGFORMER_INPUTS_DOCSTRING = r""" `What are input IDs? <../glossary.html#input-ids>`__ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): - Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens). + Mask to avoid performing attention on padding token indices. + Mask values selected in ``[0, 1]``: + ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens. + + `What are attention masks? <../glossary.html#attention-mask>`__ + + global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): + Mask to decide the attention given on each token, local attention or global attenion. Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for task-specific finetuning because it makes the model more flexible at representing the task. For example, for classification, the token should be given global attention. For QA, all question tokens should also have global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details. - Mask values selected in ``[0, 1, 2]``: - ``0`` for no attention (padding tokens), - ``1`` for local attention (a sliding window attention), - ``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). + Mask values selected in ``[0, 1]``: + ``0`` for local attention (a sliding window attention), + ``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them). - `What are attention masks? <../glossary.html#attention-mask>`__ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`): Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1`` @@ -542,6 +585,7 @@ class LongformerModel(RobertaModel): self, input_ids=None, attention_mask=None, + global_attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -593,6 +637,19 @@ class LongformerModel(RobertaModel): if isinstance(self.config.attention_window, int) else max(self.config.attention_window) ) + + # merge `global_attention_mask` and `attention_mask` + if global_attention_mask is not None: + # longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn) + # (global_attention_mask + 1) => 1 for local attention, 2 for global attention + # => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention + if attention_mask is not None: + attention_mask = attention_mask * (global_attention_mask + 1) + else: + # simply use `global_attention_mask` as `attention_mask` + # if no `attention_mask` is given + attention_mask = global_attention_mask + 1 + padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size( input_ids=input_ids, attention_mask=attention_mask, @@ -646,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): self, input_ids=None, attention_mask=None, + global_attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -695,6 +753,7 @@ class LongformerForMaskedLM(BertPreTrainedModel): outputs = self.longformer( input_ids, attention_mask=attention_mask, + global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -734,6 +793,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel): self, input_ids=None, attention_mask=None, + global_attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -778,15 +838,16 @@ class LongformerForSequenceClassification(BertPreTrainedModel): """ - if attention_mask is None: - attention_mask = torch.ones_like(input_ids) - - # global attention on cls token - attention_mask[:, 0] = 2 + if global_attention_mask is None: + logger.info("Initializing global attention on CLS token...") + global_attention_mask = torch.zeros_like(input_ids) + # global attention on cls token + global_attention_mask[:, 0] = 1 outputs = self.longformer( input_ids, attention_mask=attention_mask, + global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -846,31 +907,12 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): self.init_weights() - def _compute_global_attention_mask(self, input_ids): - question_end_index = self._get_question_end_index(input_ids) - question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1 - # bool attention mask with True in locations of global attention - attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device) - attention_mask = attention_mask.expand_as(input_ids) < question_end_index - - return attention_mask.long() + 1 # True => global attention; False => local attention - - def _get_question_end_index(self, input_ids): - sep_token_indices = (input_ids == self.config.sep_token_id).nonzero() - batch_size = input_ids.shape[0] - - assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions" - assert ( - sep_token_indices.shape[0] == 3 * batch_size - ), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering" - - return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1] - @add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)")) def forward( self, input_ids, attention_mask=None, + global_attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -929,17 +971,15 @@ class LongformerForQuestionAnswering(BertPreTrainedModel): """ # set global attention on question tokens - global_attention_mask = self._compute_global_attention_mask(input_ids) - if attention_mask is None: - attention_mask = global_attention_mask - else: - # combine global_attention_mask with attention_mask - # global attention on question tokens, no attention on padding tokens - attention_mask = global_attention_mask * attention_mask + if global_attention_mask is None: + logger.info("Initializing global attention on question tokens...") + # put global attention on all tokens until `config.sep_token_id` is reached + global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id) outputs = self.longformer( input_ids, attention_mask=attention_mask, + global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -998,6 +1038,7 @@ class LongformerForTokenClassification(BertPreTrainedModel): self, input_ids=None, attention_mask=None, + global_attention_mask=None, token_type_ids=None, position_ids=None, inputs_embeds=None, @@ -1043,6 +1084,7 @@ class LongformerForTokenClassification(BertPreTrainedModel): outputs = self.longformer( input_ids, attention_mask=attention_mask, + global_attention_mask=global_attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, @@ -1097,6 +1139,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel): input_ids=None, token_type_ids=None, attention_mask=None, + global_attention_mask=None, labels=None, position_ids=None, inputs_embeds=None, @@ -1129,29 +1172,51 @@ class LongformerForMultipleChoice(BertPreTrainedModel): Examples:: - from transformers import LongformerTokenizer, LongformerForTokenClassification + from transformers import LongformerTokenizer, LongformerForMultipleChoice import torch - tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096') - model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096') - choices = ["Hello, my dog is cute", "Hello, my cat is amazing"] - input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices + tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096') + model = LongformerForMultipleChoice.from_pretrained('allenai/longformer-base-4096') + # context = "The dog is cute" | choice = "the dog" / "the cat" + choices = [("The dog is cute", "the dog"), ("The dog is cute", "the cat")] + input_ids = torch.tensor([tokenizer.encode(s[0], s[1], add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices labels = torch.tensor(1).unsqueeze(0) # Batch size 1 + + # global attention is automatically put on "the dog" and "the cat" outputs = model(input_ids, labels=labels) loss, classification_scores = outputs[:2] """ num_choices = input_ids.shape[1] + # set global attention on question tokens + if global_attention_mask is None: + logger.info("Initializing global attention on multiple choice...") + # put global attention on all tokens after `config.sep_token_id` + global_attention_mask = torch.stack( + [ + _compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False) + for i in range(num_choices) + ], + dim=1, + ) + flat_input_ids = input_ids.view(-1, input_ids.size(-1)) flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None + flat_global_attention_mask = ( + global_attention_mask.view(-1, global_attention_mask.size(-1)) + if global_attention_mask is not None + else None + ) + outputs = self.longformer( flat_input_ids, position_ids=flat_position_ids, token_type_ids=flat_token_type_ids, attention_mask=flat_attention_mask, + global_attention_mask=flat_global_attention_mask, ) pooled_output = outputs[1] diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 00a67d716f..51429c8457 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -184,6 +184,7 @@ class LongformerModelTester(object): loss, start_logits, end_logits = model( input_ids, attention_mask=input_mask, + global_attention_mask=input_mask, token_type_ids=token_type_ids, start_positions=sequence_labels, end_positions=sequence_labels, @@ -239,9 +240,11 @@ class LongformerModelTester(object): multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() + multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous() loss, logits = model( multiple_choice_inputs_ids, attention_mask=multiple_choice_input_mask, + global_attention_mask=multiple_choice_input_mask, token_type_ids=multiple_choice_token_type_ids, labels=choice_labels, ) @@ -330,7 +333,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase): class LongformerModelIntegrationTest(unittest.TestCase): @slow def test_inference_no_head(self): - model = LongformerModel.from_pretrained("longformer-base-4096") + model = LongformerModel.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) # 'Hello world! ' repeated 1000 times @@ -350,7 +353,7 @@ class LongformerModelIntegrationTest(unittest.TestCase): @slow def test_inference_masked_lm(self): - model = LongformerForMaskedLM.from_pretrained("longformer-base-4096") + model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096") model.to(torch_device) # 'Hello world! ' repeated 1000 times