Add head_mask and decoder_head_mask to TF LED (#9988)
* Add head masking to TF LED * Add head_mask to Longformer + one doc piece to LED * Fix integration tests
This commit is contained in:
@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
if is_tf_available()
|
||||
else ()
|
||||
)
|
||||
test_head_masking = False
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = TFLongformerModelTester(self)
|
||||
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
|
||||
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
|
||||
|
||||
layer_head_mask = None
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
|
||||
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
|
||||
)[0]
|
||||
|
||||
expected_slice = tf.convert_to_tensor(
|
||||
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
layer_head_mask = None
|
||||
|
||||
output_hidden_states = layer(
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
|
||||
[
|
||||
hidden_states,
|
||||
-tf.math.abs(attention_mask),
|
||||
layer_head_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
]
|
||||
)[0]
|
||||
|
||||
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
|
||||
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
|
||||
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
|
||||
is_global_attn = tf.math.reduce_any(is_index_global_attn)
|
||||
|
||||
layer_head_mask = None
|
||||
|
||||
output_hidden_states, local_attentions, global_attentions = layer(
|
||||
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
|
||||
[
|
||||
hidden_states,
|
||||
-tf.math.abs(attention_mask),
|
||||
layer_head_mask,
|
||||
is_index_masked,
|
||||
is_index_global_attn,
|
||||
is_global_attn,
|
||||
]
|
||||
)
|
||||
|
||||
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))
|
||||
|
||||
Reference in New Issue
Block a user