[FlexAttention] Reenable flex for encoder-decoder and make the test more robust (#38321)
* reenable most flex attention test cases * style * trigger * trigger
This commit is contained in:
@@ -4570,20 +4570,29 @@ class ModelTesterMixin:
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config._attn_implementation = "flex_attention"
|
||||
# Flex Attention can not use dropout
|
||||
# Flex Attention cannot use dropout
|
||||
if hasattr(config, "attention_dropout"):
|
||||
config.attention_dropout = 0
|
||||
if hasattr(config, "attention_probs_dropout_prob"):
|
||||
config.attention_probs_dropout_prob = 0
|
||||
|
||||
# Flex attention relies on triton on compilation
|
||||
# However, triton cannot handle hidden dimensions of less than 16
|
||||
# --> forcing at least a hidden dim of 16
|
||||
config.hidden_size *= max(
|
||||
16 // getattr(config, "head_dim", config.hidden_size // config.num_attention_heads), 1
|
||||
)
|
||||
if hasattr(config, "head_dim"):
|
||||
config.head_dim = max(16, config.head_dim)
|
||||
|
||||
model = model_class(config).to(device=torch_device)
|
||||
self.assertTrue(model.config._attn_implementation == "flex_attention")
|
||||
|
||||
# Elaborate workaround for encoder-decoder models as some do not specify their main input
|
||||
dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
|
||||
if config.is_encoder_decoder:
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
|
||||
dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"].to(torch_device)
|
||||
dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"].to(torch_device)
|
||||
|
||||
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
|
||||
_ = model(**dummy_inputs)
|
||||
|
||||
Reference in New Issue
Block a user