Output global_attentions in Longformer models (#7562)

* Output global_attentions in Longformer models

* make style

* small refactoring

* fix tests

* make fix-copies

* add for tf as well

* remove comments in test

* make fix-copies

* make style

* add docs

* make docstring pretty

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Guillaume Filion
2020-11-05 15:10:43 -05:00
committed by GitHub
parent 7abc1d96d1
commit 27b402cab0
7 changed files with 684 additions and 155 deletions

View File

@@ -436,7 +436,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
def test_layer_local_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
layer = model.longformer.encoder.layer[0].attention.self_attention
hidden_states = self._get_hidden_states()
batch_size, seq_length, hidden_size = hidden_states.shape
@@ -449,7 +449,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
)[0]
expected_slice = tf.convert_to_tensor(
@@ -460,7 +460,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)
def test_layer_global_attn(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
layer = model.longformer.encoder.layer[0].attention.self_attention
hidden_states = self._get_hidden_states()
@@ -481,7 +481,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -496,6 +496,74 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3)
tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3)
def test_layer_attn_probs(self):
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
layer = model.longformer.encoder.layer[0].attention.self_attention
hidden_states = tf.concat([self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
batch_size, seq_length, hidden_size = hidden_states.shape
# create attn mask
attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
output_hidden_states, local_attentions, global_attentions = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))
self.assertEqual(global_attentions.shape, (2, 2, 3, 4))
self.assertTrue((local_attentions[0, 2:4, :, :] == 0).numpy().tolist())
self.assertTrue((local_attentions[1, 1:4, :, :] == 0).numpy().tolist())
#
# The weight of all tokens with local attention must sum to 1.
self.assertTrue(
(tf.math.abs(tf.math.reduce_sum(global_attentions[0, :, :2, :], axis=-1) - 1) < 1e-6).numpy().tolist()
)
self.assertTrue(
(tf.math.abs(tf.math.reduce_sum(global_attentions[1, :, :1, :], axis=-1) - 1) < 1e-6).numpy().tolist()
)
tf.debugging.assert_near(
local_attentions[0, 0, 0, :],
tf.convert_to_tensor(
[0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000], dtype=tf.dtypes.float32
),
rtol=1e-3,
)
tf.debugging.assert_near(
local_attentions[1, 0, 0, :],
tf.convert_to_tensor(
[0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000], dtype=tf.dtypes.float32
),
rtol=1e-3,
)
# All the global attention weights must sum to 1.
self.assertTrue((tf.math.abs(tf.math.reduce_sum(global_attentions, axis=-1) - 1) < 1e-6).numpy().tolist())
tf.debugging.assert_near(
global_attentions[0, 0, 1, :],
tf.convert_to_tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=tf.dtypes.float32),
rtol=1e-3,
)
tf.debugging.assert_near(
global_attentions[1, 0, 0, :],
tf.convert_to_tensor([0.2497, 0.2500, 0.2499, 0.2504], dtype=tf.dtypes.float32),
rtol=1e-3,
)
@slow
def test_inference_no_head(self):
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")