Output global_attentions in Longformer models (#7562)

* Output global_attentions in Longformer models

* make style

* small refactoring

* fix tests

* make fix-copies

* add for tf as well

* remove comments in test

* make fix-copies

* make style

* add docs

* make docstring pretty

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Guillaume Filion
2020-11-05 15:10:43 -05:00
committed by GitHub
parent 7abc1d96d1
commit 27b402cab0
7 changed files with 684 additions and 155 deletions

View File

@@ -71,6 +71,8 @@ class LongformerModelTester:
# [num_attention_heads, encoder_seq_length, encoder_key_length], but LongformerSelfAttention
# returns attention of shape [num_attention_heads, encoder_seq_length, self.attention_window + 1]
# because its local attention only attends to `self.attention_window + 1` locations
# (assuming no token with global attention, otherwise the last dimension of attentions
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
self.key_length = self.attention_window + 1
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
@@ -476,9 +478,20 @@ class LongformerModelIntegrationTest(unittest.TestCase):
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
attention_mask[:, :, :, -2:] = -10000
output_hidden_states = layer(hidden_states, attention_mask)[0]
attention_mask = torch.zeros((batch_size, seq_length), dtype=torch.float32, device=torch_device)
attention_mask[:, -2:] = -10000
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _ = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
self.assertTrue(output_hidden_states.shape, (1, 4, 8))
self.assertTrue(
@@ -499,13 +512,24 @@ class LongformerModelIntegrationTest(unittest.TestCase):
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0)
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, 1, 1, seq_length), dtype=torch.float32, device=torch_device)
attention_mask = torch.zeros((batch_size, seq_length), dtype=torch.float32, device=torch_device)
# create attn mask
attention_mask[0, :, :, -2:] = 10000.0
attention_mask[0, :, :, -1:] = -10000.0
attention_mask[1, :, :, 1:] = 10000.0
output_hidden_states = layer(hidden_states, attention_mask)[0]
attention_mask[0, -2:] = 10000.0
attention_mask[0, -1:] = -10000.0
attention_mask[1, 1:] = 10000.0
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, _, _ = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -533,6 +557,93 @@ class LongformerModelIntegrationTest(unittest.TestCase):
)
)
def test_layer_attn_probs(self):
model = LongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
model.eval()
layer = model.encoder.layer[0].attention.self.to(torch_device)
hidden_states = torch.cat([self._get_hidden_states(), self._get_hidden_states() - 0.5], dim=0)
batch_size, seq_length, hidden_size = hidden_states.size()
attention_mask = torch.zeros((batch_size, seq_length), dtype=torch.float32, device=torch_device)
# create attn mask
attention_mask[0, -2:] = 10000.0
attention_mask[0, -1:] = -10000.0
attention_mask[1, 1:] = 10000.0
is_index_masked = attention_mask < 0
is_index_global_attn = attention_mask > 0
is_global_attn = is_index_global_attn.flatten().any().item()
output_hidden_states, local_attentions, global_attentions = layer(
hidden_states,
attention_mask=attention_mask,
is_index_masked=is_index_masked,
is_index_global_attn=is_index_global_attn,
is_global_attn=is_global_attn,
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))
self.assertEqual(global_attentions.shape, (2, 2, 3, 4))
# All tokens with global attention have weight 0 in local attentions.
self.assertTrue(torch.all(local_attentions[0, 2:4, :, :] == 0))
self.assertTrue(torch.all(local_attentions[1, 1:4, :, :] == 0))
# The weight of all tokens with local attention must sum to 1.
self.assertTrue(torch.all(torch.abs(global_attentions[0, :, :2, :].sum(dim=-1) - 1) < 1e-6))
self.assertTrue(torch.all(torch.abs(global_attentions[1, :, :1, :].sum(dim=-1) - 1) < 1e-6))
self.assertTrue(
torch.allclose(
local_attentions[0, 0, 0, :],
torch.tensor(
[0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
self.assertTrue(
torch.allclose(
local_attentions[1, 0, 0, :],
torch.tensor(
[0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
# All the global attention weights must sum to 1.
self.assertTrue(torch.all(torch.abs(global_attentions.sum(dim=-1) - 1) < 1e-6))
self.assertTrue(
torch.allclose(
global_attentions[0, 0, 1, :],
torch.tensor(
[0.2500, 0.2500, 0.2500, 0.2500],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
self.assertTrue(
torch.allclose(
global_attentions[1, 0, 0, :],
torch.tensor(
[0.2497, 0.2500, 0.2499, 0.2504],
dtype=torch.float32,
device=torch_device,
),
atol=1e-3,
)
)
@slow
def test_inference_no_head(self):
model = LongformerModel.from_pretrained("allenai/longformer-base-4096")
@@ -541,6 +652,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
# 'Hello world!'
input_ids = torch.tensor([[0, 20920, 232, 328, 1437, 2]], dtype=torch.long, device=torch_device)
attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
output = model(input_ids, attention_mask=attention_mask)[0]
output_without_mask = model(input_ids)[0]