TF BART models - Add cross_attentions to model output and fix cross-attention head masking (#10699)

* Add cross_attn_head_mask to BART

* Fix cross_attentions in TFBart-like models

* This commit enables returning of `cross_attentions`
for TFBart-like models

* It also fixes attention head masking in cross-attenion module

* Update TF model templates

* Fix missing , in TF model templates

* Fix typo: congig -> config
This commit is contained in:
Daniel Stancl
2021-04-26 14:16:21 +02:00
committed by GitHub
parent 4bd6b54fa4
commit 38a716cd41
15 changed files with 643 additions and 216 deletions

View File

@@ -146,6 +146,7 @@ def prepare_blenderbot_inputs_dict(
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
):
if attention_mask is None:
attention_mask = tf.cast(tf.math.not_equal(input_ids, config.pad_token_id), tf.int8)
@@ -161,6 +162,8 @@ def prepare_blenderbot_inputs_dict(
head_mask = tf.ones((config.encoder_layers, config.encoder_attention_heads))
if decoder_head_mask is None:
decoder_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
if cross_attn_head_mask is None:
cross_attn_head_mask = tf.ones((config.decoder_layers, config.decoder_attention_heads))
return {
"input_ids": input_ids,
"decoder_input_ids": decoder_input_ids,
@@ -168,6 +171,7 @@ def prepare_blenderbot_inputs_dict(
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
}