[Longformer] Better handling of global attention mask vs local attention mask (#4672)

* better api

* improve automatic setting of global attention mask

* fix longformer bug

* fix global attention mask in test

* fix flattening of the global attention mask

* fix slow tests

* update docstring

* update docs and make more robust

* improve attention mask
This commit is contained in:
Patrick von Platen
2020-05-29 17:58:42 +02:00
committed by GitHub
parent e2230ba77b
commit 56ee2560be
3 changed files with 114 additions and 46 deletions

View File

@@ -184,6 +184,7 @@ class LongformerModelTester(object):
loss, start_logits, end_logits = model(
input_ids,
attention_mask=input_mask,
global_attention_mask=input_mask,
token_type_ids=token_type_ids,
start_positions=sequence_labels,
end_positions=sequence_labels,
@@ -239,9 +240,11 @@ class LongformerModelTester(object):
multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
loss, logits = model(
multiple_choice_inputs_ids,
attention_mask=multiple_choice_input_mask,
global_attention_mask=multiple_choice_input_mask,
token_type_ids=multiple_choice_token_type_ids,
labels=choice_labels,
)
@@ -330,7 +333,7 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("longformer-base-4096")
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times
@@ -350,7 +353,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
@slow
def test_inference_masked_lm(self):
model = LongformerForMaskedLM.from_pretrained("longformer-base-4096")
model = LongformerForMaskedLM.from_pretrained("allenai/longformer-base-4096")
model.to(torch_device)
# 'Hello world! ' repeated 1000 times