Add head_mask and decoder_head_mask to TF LED (#9988)

* Add head masking to TF LED

* Add head_mask to Longformer + one doc piece to LED

* Fix integration tests
This commit is contained in:
Daniel Stancl
2021-02-09 17:45:18 +01:00
committed by GitHub
parent 77c0ce8c0c
commit e7381c4596
4 changed files with 217 additions and 11 deletions

View File

@@ -162,6 +162,8 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -173,11 +175,17 @@ def prepare_led_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)

View File

@@ -297,7 +297,6 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
if is_tf_available()
else ()
)
test_head_masking = False
def setUp(self):
self.model_tester = TFLongformerModelTester(self)
@@ -517,8 +516,10 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
attention_mask = tf.where(tf.range(4)[None, :, None, None] > 1, -10000.0, attention_mask[:, :, None, None])
is_index_masked = tf.math.less(attention_mask[:, :, 0, 0], 0)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, attention_mask, is_index_masked, is_index_global_attn, is_global_attn]
[hidden_states, attention_mask, layer_head_mask, is_index_masked, is_index_global_attn, is_global_attn]
)[0]
expected_slice = tf.convert_to_tensor(
@@ -549,8 +550,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)[0]
self.assertTrue(output_hidden_states.shape, (2, 4, 8))
@@ -584,8 +594,17 @@ class TFLongformerModelIntegrationTest(unittest.TestCase):
is_index_global_attn = tf.math.greater(attention_mask[:, :, 0, 0], 0)
is_global_attn = tf.math.reduce_any(is_index_global_attn)
layer_head_mask = None
output_hidden_states, local_attentions, global_attentions = layer(
[hidden_states, -tf.math.abs(attention_mask), is_index_masked, is_index_global_attn, is_global_attn]
[
hidden_states,
-tf.math.abs(attention_mask),
layer_head_mask,
is_index_masked,
is_index_global_attn,
is_global_attn,
]
)
self.assertEqual(local_attentions.shape, (2, 4, 2, 8))