TF: BART compatible with XLA generation (#17479)

* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
Joao Gante
2022-06-20 11:07:46 +01:00
committed by GitHub
parent 6589e510fa
commit 132402d752
18 changed files with 421 additions and 86 deletions

View File

@@ -214,10 +214,10 @@ class TFModelTesterMixin:
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else [])
expected_arg_names.extend(
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
)
# Necessary to handle BART with newly added cross_attn_head_mask
expected_arg_names.extend(
["cross_attn_head_mask", "encoder_outputs"]
if "cross_attn_head_mask" in arg_names