Output global_attentions in Longformer models (#7562)
* Output global_attentions in Longformer models
* make style
* small refactoring
* fix tests
* make fix-copies
* add for tf as well
* remove comments in test
* make fix-copies
* make style
* add docs
* make docstring pretty

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -436,7 +436,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
tf.debugging.assert_near(chunked_hidden_states[0, 0, :, 0], expected_slice_along_chunk, rtol=1e-3)
|
||||
|
||||
def test_layer_local_attn(self):
|
||||
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
|
||||
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
|
||||
layer = model.longformer.encoder.layer[0].attention.self_attention
|
||||
hidden_states = self._get_hidden_states()
|
||||
batch_size, seq_length, hidden_size = hidden_states.shape
|
||||
@@ -449,7 +449,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
|
||||
)[0]
|
||||
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
@@ -460,7 +460,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
tf.debugging.assert_near(output_hidden_states[0, 1], expected_slice, rtol=1e-3)
|
||||
|
||||
def test_layer_global_attn(self):
|
||||
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny", use_cdn=False)
|
||||
model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
|
||||
layer = model.longformer.encoder.layer[0].attention.self_attention
|
||||
hidden_states = self._get_hidden_states()
|
||||
|
||||
@@ -481,7 +481,7 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn, None]
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||
@@ -496,6 +496,74 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
tf.debugging.assert_near(output_hidden_states[0, 2], expected_slice_0, rtol=1e-3)
|
||||
tf.debugging.assert_near(output_hidden_states[1, -2], expected_slice_1, rtol=1e-3)
|
||||
|
||||
def test_layer_attn_probs(self):
    """Check the attention probabilities returned by the first self-attention layer.

    Two samples are stacked along the batch axis. The mask gives sample 0 the
    value 10000.0 at position 2 and -10000.0 at position 3, and gives sample 1
    the value 10000.0 at positions 1-3. Positive entries are flagged as global
    attention, negative entries as masked; the layer itself receives
    ``-abs(mask)``. The layer is expected to return a triple of
    ``(hidden_states, local_attentions, global_attentions)``.
    """
    model = TFLongformerModel.from_pretrained("patrickvonplaten/longformer-random-tiny")
    layer = model.longformer.encoder.layer[0].attention.self_attention
    hidden_states = tf.concat([self._get_hidden_states(), self._get_hidden_states() - 0.5], axis=0)
    seq_length = hidden_states.shape[1]

    # create attn mask
    attention_mask_1 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)
    attention_mask_2 = tf.zeros((1, 1, 1, seq_length), dtype=tf.dtypes.float32)

    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 1, 10000.0, attention_mask_1)
    attention_mask_1 = tf.where(tf.range(4)[None, :, None, None] > 2, -10000.0, attention_mask_1)
    attention_mask_2 = tf.where(tf.range(4)[None, :, None, None] > 0, 10000.0, attention_mask_2)
    attention_mask = tf.concat([attention_mask_1, attention_mask_2], axis=0)

    is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
    is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
    is_global_attn = tf.math.reduce_any(is_index_global_attn)

    output_hidden_states, local_attentions, global_attentions = layer(
        [hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
    )

    self.assertEqual(local_attentions.shape, (2, 4, 2, 8))
    self.assertEqual(global_attentions.shape, (2, 2, 3, 4))

    # NOTE: assertTrue on `<tensor>.numpy().tolist()` always passes, because a
    # non-empty list is truthy. Reduce to a single boolean so the checks can fail.

    # Positions with a non-zero mask entry (2-3 in sample 0, 1-3 in sample 1)
    # must carry exactly zero weight in the local attentions.
    self.assertTrue(tf.math.reduce_all(local_attentions[0, 2:4, :, :] == 0).numpy())
    self.assertTrue(tf.math.reduce_all(local_attentions[1, 1:4, :, :] == 0).numpy())

    # The selected rows of the global attentions must each sum to 1 over the last axis.
    self.assertTrue(
        tf.math.reduce_all(
            tf.math.abs(tf.math.reduce_sum(global_attentions[0, :, :2, :], axis=-1) - 1) < 1e-6
        ).numpy()
    )
    self.assertTrue(
        tf.math.reduce_all(
            tf.math.abs(tf.math.reduce_sum(global_attentions[1, :, :1, :], axis=-1) - 1) < 1e-6
        ).numpy()
    )

    tf.debugging.assert_near(
        local_attentions[0, 0, 0, :],
        tf.convert_to_tensor(
            [0.3328, 0.0000, 0.0000, 0.0000, 0.0000, 0.3355, 0.3318, 0.0000], dtype=tf.dtypes.float32
        ),
        rtol=1e-3,
    )

    tf.debugging.assert_near(
        local_attentions[1, 0, 0, :],
        tf.convert_to_tensor(
            [0.2492, 0.2502, 0.2502, 0.0000, 0.0000, 0.2505, 0.0000, 0.0000], dtype=tf.dtypes.float32
        ),
        rtol=1e-3,
    )

    # All the global attention weights must sum to 1.
    self.assertTrue(
        tf.math.reduce_all(tf.math.abs(tf.math.reduce_sum(global_attentions, axis=-1) - 1) < 1e-6).numpy()
    )

    tf.debugging.assert_near(
        global_attentions[0, 0, 1, :],
        tf.convert_to_tensor([0.2500, 0.2500, 0.2500, 0.2500], dtype=tf.dtypes.float32),
        rtol=1e-3,
    )
    tf.debugging.assert_near(
        global_attentions[1, 0, 0, :],
        tf.convert_to_tensor([0.2497, 0.2500, 0.2499, 0.2504], dtype=tf.dtypes.float32),
        rtol=1e-3,
    )
|
||||
|
||||
@slow
|
||||
def test_inference_no_head(self):
|
||||
model = TFLongformerModel.from_pretrained("allenai/longformer-base-4096")
|
||||
|
||||
Reference in New Issue
Block a user