[Longformer] Better handling of global attention mask vs local attention mask (#4672)
* better api * improve automatic setting of global attention mask * fix longformer bug * fix global attention mask in test * fix global attn mask flatten * fix slow tests * update docstring * update docs and make more robust * improve attention mask
This commit is contained in:
committed by
GitHub
parent
e2230ba77b
commit
56ee2560be
@@ -39,6 +39,44 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def _get_question_end_index(input_ids, sep_token_id):
    """
    Computes the index of the first occurrence of `sep_token_id` in every sample.

    Args:
        input_ids: two-dimensional tensor of token ids (one row per sample).
        sep_token_id: id of the separator token to look for.

    Returns:
        One-dimensional tensor with one entry per row of `input_ids`, holding the
        column index of the first `sep_token_id` in that row.

    Raises:
        AssertionError: if `input_ids` is not two-dimensional, or if any single row
            does not contain exactly three `sep_token_id` tokens.
    """
    assert input_ids.dim() == 2, "`input_ids` should have two dimensions"

    sep_token_mask = input_ids == sep_token_id
    # Validate the count per sample, not only the batch-wide total: a batch with e.g. 4 and 2
    # separators sums to 3 * batch_size as well, but would make the reshape below mix up rows.
    assert bool(
        (sep_token_mask.sum(dim=1) == 3).all()
    ), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."

    batch_size = input_ids.shape[0]
    # `nonzero` yields one (row, column) pair per separator; with exactly three per row they
    # regroup into batch_size x 3 x 2, and [:, 0, 1] picks the column of each row's first one.
    sep_token_indices = sep_token_mask.nonzero()
    return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
|
||||
|
||||
|
||||
def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
    """
    Computes a global attention mask relative to the first `sep_token_id` of every sample.

    Args:
        input_ids: tensor of token ids; it is passed to `_get_question_end_index`
            to locate the first `sep_token_id` of each row.
        sep_token_id: id of the separator token.
        before_sep_token: if truthy, mark every position strictly before the first
            separator; otherwise mark every position more than one place past it.

    Returns:
        `torch.uint8` tensor with the same shape as `input_ids`, holding 1 at the
        marked (global attention) positions and 0 everywhere else.
    """
    # column of the first separator per sample, shaped batch_size x 1 so it broadcasts over positions
    question_end_index = _get_question_end_index(input_ids, sep_token_id).unsqueeze(dim=1)
    positions = torch.arange(input_ids.shape[1], device=input_ids.device).expand_as(input_ids)

    if before_sep_token:
        global_positions = positions < question_end_index
    else:
        # Skip the first separator and the position right after it ("in the middle are two
        # separation tokens"). There is no upper bound: every later position is marked,
        # including the final one. The former extra factor `positions < input_ids.shape[-1]`
        # was true for every position and therefore never excluded anything.
        global_positions = positions > (question_end_index + 1)

    return global_positions.to(torch.uint8)
|
||||
|
||||
|
||||
class LongformerSelfAttention(nn.Module):
|
||||
def __init__(self, config, layer_id):
|
||||
super().__init__()
|
||||
@@ -420,17 +458,22 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
|
||||
|
||||
`What are input IDs? <../glossary.html#input-ids>`__
|
||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||
Mask to decide the attention given on each token, local attention, global attention, or no attention (for padding tokens).
|
||||
Mask to avoid performing attention on padding token indices.
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||
Mask to decide the attention given on each token, local attention or global attention.
|
||||
Tokens with global attention attend to all other tokens, and all other tokens attend to them. This is important for
|
||||
task-specific finetuning because it makes the model more flexible at representing the task. For example,
|
||||
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
|
||||
global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details.
|
||||
Mask values selected in ``[0, 1, 2]``:
|
||||
``0`` for no attention (padding tokens),
|
||||
``1`` for local attention (a sliding window attention),
|
||||
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
Mask values selected in ``[0, 1]``:
|
||||
``0`` for local attention (a sliding window attention),
|
||||
``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||
Segment token indices to indicate first and second portions of the inputs.
|
||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||
@@ -542,6 +585,7 @@ class LongformerModel(RobertaModel):
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -593,6 +637,19 @@ class LongformerModel(RobertaModel):
|
||||
if isinstance(self.config.attention_window, int)
|
||||
else max(self.config.attention_window)
|
||||
)
|
||||
|
||||
# merge `global_attention_mask` and `attention_mask`
|
||||
if global_attention_mask is not None:
|
||||
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
|
||||
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
|
||||
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
|
||||
if attention_mask is not None:
|
||||
attention_mask = attention_mask * (global_attention_mask + 1)
|
||||
else:
|
||||
# simply use `global_attention_mask` as `attention_mask`
|
||||
# if no `attention_mask` is given
|
||||
attention_mask = global_attention_mask + 1
|
||||
|
||||
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
@@ -646,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -695,6 +753,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -734,6 +793,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -778,15 +838,16 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
|
||||
|
||||
"""
|
||||
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones_like(input_ids)
|
||||
|
||||
# global attention on cls token
|
||||
attention_mask[:, 0] = 2
|
||||
if global_attention_mask is None:
|
||||
logger.info("Initializing global attention on CLS token...")
|
||||
global_attention_mask = torch.zeros_like(input_ids)
|
||||
# global attention on cls token
|
||||
global_attention_mask[:, 0] = 1
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -846,31 +907,12 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _compute_global_attention_mask(self, input_ids):
    """
    Builds a combined attention mask from the result of `self._get_question_end_index`.

    Positions strictly before the returned per-sample index get the value 2,
    all remaining positions get the value 1. The result is a long tensor with
    the same shape as `input_ids`.
    """
    # per-sample end index, shaped batch_size x 1 so it broadcasts against the positions
    first_sep_column = self._get_question_end_index(input_ids).unsqueeze(dim=1)
    positions = torch.arange(input_ids.shape[1], device=input_ids.device).expand_as(input_ids)
    is_question_token = positions < first_sep_column

    # True => 2 (global attention); False => 1 (local attention)
    return is_question_token.long() + 1
|
||||
|
||||
def _get_question_end_index(self, input_ids):
    """
    Computes the index of the first occurrence of `self.config.sep_token_id` in every sample.

    Args:
        input_ids: two-dimensional tensor of token ids (one row per sample).

    Returns:
        One-dimensional tensor with one entry per row of `input_ids`, holding the
        column index of the first separator token in that row.

    Raises:
        AssertionError: if `input_ids` is not two-dimensional, or if any single row
            does not contain exactly three separator tokens.
    """
    assert input_ids.dim() == 2, "`input_ids` should have two dimensions"

    sep_token_mask = input_ids == self.config.sep_token_id
    # Check every sample on its own; a batch-wide total of 3 * batch_size can hide rows
    # with too many / too few separators and would corrupt the reshape below.
    assert bool(
        (sep_token_mask.sum(dim=1) == 3).all()
    ), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering"

    batch_size = input_ids.shape[0]
    # one (row, column) pair per separator -> batch_size x 3 x 2; take the column of the first one
    return sep_token_mask.nonzero().view(batch_size, 3, 2)[:, 0, 1]
|
||||
|
||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||
def forward(
|
||||
self,
|
||||
input_ids,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -929,17 +971,15 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
|
||||
"""
|
||||
|
||||
# set global attention on question tokens
|
||||
global_attention_mask = self._compute_global_attention_mask(input_ids)
|
||||
if attention_mask is None:
|
||||
attention_mask = global_attention_mask
|
||||
else:
|
||||
# combine global_attention_mask with attention_mask
|
||||
# global attention on question tokens, no attention on padding tokens
|
||||
attention_mask = global_attention_mask * attention_mask
|
||||
if global_attention_mask is None:
|
||||
logger.info("Initializing global attention on question tokens...")
|
||||
# put global attention on all tokens until `config.sep_token_id` is reached
|
||||
global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
|
||||
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -998,6 +1038,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
|
||||
self,
|
||||
input_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
token_type_ids=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1043,6 +1084,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
|
||||
outputs = self.longformer(
|
||||
input_ids,
|
||||
attention_mask=attention_mask,
|
||||
global_attention_mask=global_attention_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
@@ -1097,6 +1139,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
||||
input_ids=None,
|
||||
token_type_ids=None,
|
||||
attention_mask=None,
|
||||
global_attention_mask=None,
|
||||
labels=None,
|
||||
position_ids=None,
|
||||
inputs_embeds=None,
|
||||
@@ -1129,29 +1172,51 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
||||
|
||||
Examples::
|
||||
|
||||
from transformers import LongformerTokenizer, LongformerForTokenClassification
|
||||
from transformers import LongformerTokenizer, LongformerForMultipleChoice
|
||||
import torch
|
||||
|
||||
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
|
||||
model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
|
||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
||||
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
|
||||
model = LongformerForMultipleChoice.from_pretrained('allenai/longformer-base-4096')
|
||||
# context = "The dog is cute" | choice = "the dog" / "the cat"
|
||||
choices = [("The dog is cute", "the dog"), ("The dog is cute", "the cat")]
|
||||
input_ids = torch.tensor([tokenizer.encode(s[0], s[1], add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
||||
|
||||
# global attention is automatically put on "the dog" and "the cat"
|
||||
outputs = model(input_ids, labels=labels)
|
||||
loss, classification_scores = outputs[:2]
|
||||
|
||||
"""
|
||||
num_choices = input_ids.shape[1]
|
||||
|
||||
# set global attention on question tokens
|
||||
if global_attention_mask is None:
|
||||
logger.info("Initializing global attention on multiple choice...")
|
||||
# put global attention on all tokens after `config.sep_token_id`
|
||||
global_attention_mask = torch.stack(
|
||||
[
|
||||
_compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False)
|
||||
for i in range(num_choices)
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||
flat_global_attention_mask = (
|
||||
global_attention_mask.view(-1, global_attention_mask.size(-1))
|
||||
if global_attention_mask is not None
|
||||
else None
|
||||
)
|
||||
|
||||
outputs = self.longformer(
|
||||
flat_input_ids,
|
||||
position_ids=flat_position_ids,
|
||||
token_type_ids=flat_token_type_ids,
|
||||
attention_mask=flat_attention_mask,
|
||||
global_attention_mask=flat_global_attention_mask,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user