Add head_mask and decoder_head_mask to TF LED (#9988)

* Add head masking to TF LED

* Add head_mask to Longformer + one doc piece to LED

* Fix integration tests
This commit is contained in:
Daniel Stancl
2021-02-09 17:45:18 +01:00
committed by GitHub
parent 77c0ce8c0c
commit e7381c4596
4 changed files with 217 additions and 11 deletions

View File

@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
)[0]
expected_slice = tf.convert_to_tensor(
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states, local_attentions, global_attentions = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))