[Longformer] Better handling of global attention mask vs local attention mask (#4672)
* better api
* improve automatic setting of global attention mask
* fix longformer bug
* fix global attention mask in test
* fix global attn mask flatten
* fix slow tests
* update docstring
* update docs and make more robust
* improve attention mask
This commit is contained in:
committed by
GitHub
parent
e2230ba77b
commit
56ee2560be
@@ -184,6 +184,7 @@ class LongformerModelTester(object):
|
||||
loss, start_logits, end_logits = model(
|
||||
input_ids,
|
||||
attention_mask=input_mask,
|
||||
global_attention_mask=input_mask,
|
||||
token_type_ids=token_type_ids,
|
||||
start_positions=sequence_labels,
|
||||
end_positions=sequence_labels,
|
||||
@@ -239,9 +240,11 @@ class LongformerModelTester(object):
|
||||
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
|
||||
loss, logits = model(
|
||||
multiple_choice_inputs_ids,
|
||||
attention_mask=multiple_choice_input_mask,
|
||||
global_attention_mask=multiple_choice_input_mask,
|
||||
token_type_ids=multiple_choice_token_type_ids,
|
||||
labels=choice_labels,
|
||||
)
|
||||
@@ -330,7 +333,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = LongformerModel.from_pretrained("longformer-base-4096")
|
||||
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world! ' repeated 1000 times
|
||||
@@ -350,7 +353,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_inference_masked_lm(self):
|
||||
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
|
||||
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
|
||||
model.to(torch_device)
|
||||
|
||||
# 'Hello world! ' repeated 1000 times
|
||||
|
||||
Reference in New Issue
Block a user