Add head_mask and decoder_head_mask to TF LED (#9988)

* Add head masking to TF LED

* Add head_mask to Longformer + one doc piece to LED

* Fix integration tests
This commit is contained in:
Daniel Stancl
2021-02-09 17:45:18 +01:00
committed by GitHub
parent 77c0ce8c0c
commit e7381c4596
4 changed files with 217 additions and 11 deletions

View File

@@ -162,6 +162,8 @@ def prepare_led_inputs_dict(
decoder_input_ids,
attention_mask=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -173,11 +175,17 @@ def prepare_led_inputs_dict(
],
axis=-1,
)
if head_mask is None:
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
}
@@ -187,7 +195,6 @@ class TFLEDModelTest(TFModelTesterMixin, unittest.TestCase):
all_generative_model_classes = (TFLEDForConditionalGeneration,) if is_tf_available() else ()
is_encoder_decoder = True
test_pruning = False
test_head_masking = False
def setUp(self):
self.model_tester = TFLEDModelTester(self)