TF: BART compatible with XLA generation (#17479)
* Also propagate changes to blenderbot, blenderbot_small, marian, mbart, and pegasus
This commit is contained in:
@@ -214,10 +214,10 @@ class TFModelTesterMixin:
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(["decoder_position_ids"] if "decoder_position_ids" in arg_names else [])
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask"] if "head_mask" and "decoder_head_mask" in arg_names else []
|
||||
)
|
||||
# Necessary to handle BART with newly added cross_attn_head_mask
|
||||
expected_arg_names.extend(
|
||||
["cross_attn_head_mask", "encoder_outputs"]
|
||||
if "cross_attn_head_mask" in arg_names
|
||||
|
||||
Reference in New Issue
Block a user