[Longformer] Better handling of global attention mask vs local attention mask (#4672)
* better api * improve automatic setting of global attention mask * fix longformer bug * fix global attention mask in test * fix global attn mask flatten * fix slow tests * update docstring * update docs and make more robust * improve attention mask
This commit is contained in:
committed by
GitHub
parent
e2230ba77b
commit
56ee2560be
@@ -21,7 +21,7 @@ A selecetd few tokens attend "globally" to all other tokens, as it is convention
|
|||||||
Note that "locally" and "globally" attending tokens are projected by different query, key and value matrices.
|
Note that "locally" and "globally" attending tokens are projected by different query, key and value matrices.
|
||||||
Also note that every "locally" attending token not only attends to tokens within its window :math:`w`, but also to all "globally" attending tokens so that global attention is *symmetric*.
|
Also note that every "locally" attending token not only attends to tokens within its window :math:`w`, but also to all "globally" attending tokens so that global attention is *symmetric*.
|
||||||
|
|
||||||
The user can define which tokens are masked, which tokens attend "locally" and which tokens attend "globally" by setting the `config.attention_mask` `torch.Tensor` appropriately. In contrast to other models `Longformer` accepts the following values in `config.attention_mask`: `0` - the token is masked and not attended at all (as is done in other models), `1` - the token attends "locally", `2` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method.
|
The user can define which tokens attend "locally" and which tokens attend "globally" by setting the tensor `global_attention_mask` at run-time appropriately. `Longformer` employs the following logic for `global_attention_mask`: `0` - the token attends "locally", `1` - token attends "globally". For more information please also refer to :func:`~transformers.LongformerModel.forward` method.
|
||||||
|
|
||||||
Using Longformer self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times w)`, with :math:`n_s` being the sequence length and :math:`w` being the average window size. It is assumed that the number of "globally" attending tokens is insignificant as compared to the number of "locally" attending tokens.
|
Using Longformer self attention, the memory and time complexity of the query-key matmul operation, which usually represents the memory and time bottleneck, can be reduced from :math:`\mathcal{O}(n_s \times n_s)` to :math:`\mathcal{O}(n_s \times w)`, with :math:`n_s` being the sequence length and :math:`w` being the average window size. It is assumed that the number of "globally" attending tokens is insignificant as compared to the number of "locally" attending tokens.
|
||||||
|
|
||||||
|
|||||||
@@ -39,6 +39,44 @@ LONGFORMER_PRETRAINED_MODEL_ARCHIVE_MAP = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_question_end_index(input_ids, sep_token_id):
|
||||||
|
"""
|
||||||
|
Computes the index of the first occurance of `sep_token_id`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
sep_token_indices = (input_ids == sep_token_id).nonzero()
|
||||||
|
batch_size = input_ids.shape[0]
|
||||||
|
|
||||||
|
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
|
||||||
|
assert (
|
||||||
|
sep_token_indices.shape[0] == 3 * batch_size
|
||||||
|
), f"There should be exactly three separator tokens: {sep_token_id} in every sample for questions answering. You might also consider to set `global_attention_mask` manually in the forward function to avoid this error."
|
||||||
|
|
||||||
|
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_global_attention_mask(input_ids, sep_token_id, before_sep_token=True):
|
||||||
|
"""
|
||||||
|
Computes global attention mask by putting attention on all tokens
|
||||||
|
before `sep_token_id` if `before_sep_token is True` else after
|
||||||
|
`sep_token_id`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
question_end_index = _get_question_end_index(input_ids, sep_token_id)
|
||||||
|
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
|
||||||
|
# bool attention mask with True in locations of global attention
|
||||||
|
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
||||||
|
if before_sep_token is True:
|
||||||
|
attention_mask = (attention_mask.expand_as(input_ids) < question_end_index).to(torch.uint8)
|
||||||
|
else:
|
||||||
|
# last token is separation token and should not be counted and in the middle are two separation tokens
|
||||||
|
attention_mask = (attention_mask.expand_as(input_ids) > (question_end_index + 1)).to(torch.uint8) * (
|
||||||
|
attention_mask.expand_as(input_ids) < input_ids.shape[-1]
|
||||||
|
).to(torch.uint8)
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
class LongformerSelfAttention(nn.Module):
|
class LongformerSelfAttention(nn.Module):
|
||||||
def __init__(self, config, layer_id):
|
def __init__(self, config, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@@ -420,17 +458,22 @@ LONGFORMER_INPUTS_DOCSTRING = r"""
|
|||||||
|
|
||||||
`What are input IDs? <../glossary.html#input-ids>`__
|
`What are input IDs? <../glossary.html#input-ids>`__
|
||||||
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||||
Mask to decide the attention given on each token, local attention, global attenion, or no attention (for padding tokens).
|
Mask to avoid performing attention on padding token indices.
|
||||||
|
Mask values selected in ``[0, 1]``:
|
||||||
|
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
|
||||||
|
global_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||||
|
Mask to decide the attention given on each token, local attention or global attenion.
|
||||||
Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for
|
Tokens with global attention attends to all other tokens, and all other tokens attend to them. This is important for
|
||||||
task-specific finetuning because it makes the model more flexible at representing the task. For example,
|
task-specific finetuning because it makes the model more flexible at representing the task. For example,
|
||||||
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
|
for classification, the <s> token should be given global attention. For QA, all question tokens should also have
|
||||||
global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details.
|
global attention. Please refer to the Longformer paper https://arxiv.org/abs/2004.05150 for more details.
|
||||||
Mask values selected in ``[0, 1, 2]``:
|
Mask values selected in ``[0, 1]``:
|
||||||
``0`` for no attention (padding tokens),
|
``0`` for local attention (a sliding window attention),
|
||||||
``1`` for local attention (a sliding window attention),
|
``1`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
||||||
``2`` for global attention (tokens that attend to all other tokens, and all other tokens attend to them).
|
|
||||||
|
|
||||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
|
||||||
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`{0}`, `optional`, defaults to :obj:`None`):
|
||||||
Segment token indices to indicate first and second portions of the inputs.
|
Segment token indices to indicate first and second portions of the inputs.
|
||||||
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
|
||||||
@@ -542,6 +585,7 @@ class LongformerModel(RobertaModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -593,6 +637,19 @@ class LongformerModel(RobertaModel):
|
|||||||
if isinstance(self.config.attention_window, int)
|
if isinstance(self.config.attention_window, int)
|
||||||
else max(self.config.attention_window)
|
else max(self.config.attention_window)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# merge `global_attention_mask` and `attention_mask`
|
||||||
|
if global_attention_mask is not None:
|
||||||
|
# longformer self attention expects attention mask to have 0 (no attn), 1 (local attn), 2 (global attn)
|
||||||
|
# (global_attention_mask + 1) => 1 for local attention, 2 for global attention
|
||||||
|
# => final attention_mask => 0 for no attention, 1 for local attention 2 for global attention
|
||||||
|
if attention_mask is not None:
|
||||||
|
attention_mask = attention_mask * (global_attention_mask + 1)
|
||||||
|
else:
|
||||||
|
# simply use `global_attention_mask` as `attention_mask`
|
||||||
|
# if no `attention_mask` is given
|
||||||
|
attention_mask = global_attention_mask + 1
|
||||||
|
|
||||||
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
|
padding_len, input_ids, attention_mask, token_type_ids, position_ids, inputs_embeds = self._pad_to_window_size(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
@@ -646,6 +703,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -695,6 +753,7 @@ class LongformerForMaskedLM(BertPreTrainedModel):
|
|||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
global_attention_mask=global_attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -734,6 +793,7 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -778,15 +838,16 @@ class LongformerForSequenceClassification(BertPreTrainedModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if attention_mask is None:
|
if global_attention_mask is None:
|
||||||
attention_mask = torch.ones_like(input_ids)
|
logger.info("Initializing global attention on CLS token...")
|
||||||
|
global_attention_mask = torch.zeros_like(input_ids)
|
||||||
# global attention on cls token
|
# global attention on cls token
|
||||||
attention_mask[:, 0] = 2
|
global_attention_mask[:, 0] = 1
|
||||||
|
|
||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
global_attention_mask=global_attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -846,31 +907,12 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
|
|||||||
|
|
||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def _compute_global_attention_mask(self, input_ids):
|
|
||||||
question_end_index = self._get_question_end_index(input_ids)
|
|
||||||
question_end_index = question_end_index.unsqueeze(dim=1) # size: batch_size x 1
|
|
||||||
# bool attention mask with True in locations of global attention
|
|
||||||
attention_mask = torch.arange(input_ids.shape[1], device=input_ids.device)
|
|
||||||
attention_mask = attention_mask.expand_as(input_ids) < question_end_index
|
|
||||||
|
|
||||||
return attention_mask.long() + 1 # True => global attention; False => local attention
|
|
||||||
|
|
||||||
def _get_question_end_index(self, input_ids):
|
|
||||||
sep_token_indices = (input_ids == self.config.sep_token_id).nonzero()
|
|
||||||
batch_size = input_ids.shape[0]
|
|
||||||
|
|
||||||
assert sep_token_indices.shape[1] == 2, "`input_ids` should have two dimensions"
|
|
||||||
assert (
|
|
||||||
sep_token_indices.shape[0] == 3 * batch_size
|
|
||||||
), f"There should be exactly three separator tokens: {self.config.sep_token_id} in every sample for questions answering"
|
|
||||||
|
|
||||||
return sep_token_indices.view(batch_size, 3, 2)[:, 0, 1]
|
|
||||||
|
|
||||||
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
@add_start_docstrings_to_callable(LONGFORMER_INPUTS_DOCSTRING.format("(batch_size, sequence_length)"))
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -929,17 +971,15 @@ class LongformerForQuestionAnswering(BertPreTrainedModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
# set global attention on question tokens
|
# set global attention on question tokens
|
||||||
global_attention_mask = self._compute_global_attention_mask(input_ids)
|
if global_attention_mask is None:
|
||||||
if attention_mask is None:
|
logger.info("Initializing global attention on question tokens...")
|
||||||
attention_mask = global_attention_mask
|
# put global attention on all tokens until `config.sep_token_id` is reached
|
||||||
else:
|
global_attention_mask = _compute_global_attention_mask(input_ids, self.config.sep_token_id)
|
||||||
# combine global_attention_mask with attention_mask
|
|
||||||
# global attention on question tokens, no attention on padding tokens
|
|
||||||
attention_mask = global_attention_mask * attention_mask
|
|
||||||
|
|
||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
global_attention_mask=global_attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -998,6 +1038,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
|
|||||||
self,
|
self,
|
||||||
input_ids=None,
|
input_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1043,6 +1084,7 @@ class LongformerForTokenClassification(BertPreTrainedModel):
|
|||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
|
global_attention_mask=global_attention_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
position_ids=position_ids,
|
position_ids=position_ids,
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
@@ -1097,6 +1139,7 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
input_ids=None,
|
input_ids=None,
|
||||||
token_type_ids=None,
|
token_type_ids=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
|
global_attention_mask=None,
|
||||||
labels=None,
|
labels=None,
|
||||||
position_ids=None,
|
position_ids=None,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
@@ -1129,29 +1172,51 @@ class LongformerForMultipleChoice(BertPreTrainedModel):
|
|||||||
|
|
||||||
Examples::
|
Examples::
|
||||||
|
|
||||||
from transformers import LongformerTokenizer, LongformerForTokenClassification
|
from transformers import LongformerTokenizer, LongformerForMultipleChoice
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
tokenizer = LongformerTokenizer.from_pretrained('longformer-base-4096')
|
tokenizer = LongformerTokenizer.from_pretrained('allenai/longformer-base-4096')
|
||||||
model = LongformerForMultipleChoice.from_pretrained('longformer-base-4096')
|
model = LongformerForMultipleChoice.from_pretrained('allenai/longformer-base-4096')
|
||||||
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
|
# context = "The dog is cute" | choice = "the dog" / "the cat"
|
||||||
input_ids = torch.tensor([tokenizer.encode(s, add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
choices = [("The dog is cute", "the dog"), ("The dog is cute", "the cat")]
|
||||||
|
input_ids = torch.tensor([tokenizer.encode(s[0], s[1], add_special_tokens=True) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
|
||||||
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
|
||||||
|
|
||||||
|
# global attention is automatically put on "the dog" and "the cat"
|
||||||
outputs = model(input_ids, labels=labels)
|
outputs = model(input_ids, labels=labels)
|
||||||
loss, classification_scores = outputs[:2]
|
loss, classification_scores = outputs[:2]
|
||||||
|
|
||||||
"""
|
"""
|
||||||
num_choices = input_ids.shape[1]
|
num_choices = input_ids.shape[1]
|
||||||
|
|
||||||
|
# set global attention on question tokens
|
||||||
|
if global_attention_mask is None:
|
||||||
|
logger.info("Initializing global attention on multiple choice...")
|
||||||
|
# put global attention on all tokens after `config.sep_token_id`
|
||||||
|
global_attention_mask = torch.stack(
|
||||||
|
[
|
||||||
|
_compute_global_attention_mask(input_ids[:, i], self.config.sep_token_id, before_sep_token=False)
|
||||||
|
for i in range(num_choices)
|
||||||
|
],
|
||||||
|
dim=1,
|
||||||
|
)
|
||||||
|
|
||||||
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
flat_input_ids = input_ids.view(-1, input_ids.size(-1))
|
||||||
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
|
||||||
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
|
||||||
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
|
||||||
|
flat_global_attention_mask = (
|
||||||
|
global_attention_mask.view(-1, global_attention_mask.size(-1))
|
||||||
|
if global_attention_mask is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
outputs = self.longformer(
|
outputs = self.longformer(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
position_ids=flat_position_ids,
|
position_ids=flat_position_ids,
|
||||||
token_type_ids=flat_token_type_ids,
|
token_type_ids=flat_token_type_ids,
|
||||||
attention_mask=flat_attention_mask,
|
attention_mask=flat_attention_mask,
|
||||||
|
global_attention_mask=flat_global_attention_mask,
|
||||||
)
|
)
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ class LongformerModelTester(object):
|
|||||||
loss, start_logits, end_logits = model(
|
loss, start_logits, end_logits = model(
|
||||||
input_ids,
|
input_ids,
|
||||||
attention_mask=input_mask,
|
attention_mask=input_mask,
|
||||||
|
global_attention_mask=input_mask,
|
||||||
token_type_ids=token_type_ids,
|
token_type_ids=token_type_ids,
|
||||||
start_positions=sequence_labels,
|
start_positions=sequence_labels,
|
||||||
end_positions=sequence_labels,
|
end_positions=sequence_labels,
|
||||||
@@ -239,9 +240,11 @@ class LongformerModelTester(object):
|
|||||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
|
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||||
loss, logits = model(
|
loss, logits = model(
|
||||||
multiple_choice_inputs_ids,
|
multiple_choice_inputs_ids,
|
||||||
attention_mask=multiple_choice_input_mask,
|
attention_mask=multiple_choice_input_mask,
|
||||||
|
global_attention_mask=multiple_choice_input_mask,
|
||||||
token_type_ids=multiple_choice_token_type_ids,
|
token_type_ids=multiple_choice_token_type_ids,
|
||||||
labels=choice_labels,
|
labels=choice_labels,
|
||||||
)
|
)
|
||||||
@@ -330,7 +333,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
class LongformerModelIntegrationTest(unittest.TestCase):
|
class LongformerModelIntegrationTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
def test_inference_no_head(self):
|
def test_inference_no_head(self):
|
||||||
model = LongformerModel.from_pretrained("longformer-base-4096")
|
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
# 'Hello world! ' repeated 1000 times
|
# 'Hello world! ' repeated 1000 times
|
||||||
@@ -350,7 +353,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_inference_masked_lm(self):
|
def test_inference_masked_lm(self):
|
||||||
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
|
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
|
||||||
# 'Hello world! ' repeated 1000 times
|
# 'Hello world! ' repeated 1000 times
|
||||||
|
|||||||
Reference in New Issue
Block a user