🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa)
* flash attention works, remove unnecessary code
* flex attention support for bart!, gotta check if the renaming is not too aggressive
* some comments
* skip flex grad test for standalone as done with the other test
* revert flex attn rename (for now), sdpa simplify, and todos
* more todos
* refactor mask creation for reuse
* modular attempt at biogpt
* first batch of other models
* fix attn dropout
* fix autoformer copies
* hubert
* another batch of models
* copies/style + last round of bart models --> whisper next?
* remove unnecessary _reshape function and remove copy to whisper
* add skip for decoder-only models out of enc-dec (same as in bart)
* bring back licences
* remove comment, added to pr read instead
* mostly docs
* disable sew flex attn as it's unclear attn mask for now
* oops
* test fixes for enc-dec
* torch fx fixes + try at flex attn
* skip on mbart
* some more fixes
* musicgen skip / delete old attn class logic + sdpa compose compile skip
* disable flex attn for musicgen, not worth the effort
* more fixes and style
* flex attention test for dropout and encoder decoder that dont have main input names
* informer fixes
* the weirdest thing I've encountered yet...
* style
* remove empty tensor attempt, found core root in previous commits
* disable time series due to tests being very text centric on inputs
* add speech to text to be ignoring the other attns, also due to tests
* update docs
* remaining issues resolved ?
* update docs for current state --> nllb moe and pegasus x sdpa is questionable :D
* some models have not set the is_causal flag...
* change dtype in softmax tol old behaviour + some modular fixes
* I hate it but it is what it is
* fixes from main for bart
* forgot this one
* some model fixes
* style
* current status
* marian works now
* fixing some copies
* some copy fixes + time series x informer
* last models possibly and fixes on style/copies
* some post merge fixes
* more fixes
* make attention interface callable and move warnings there
* style lol
* add comment to "unsupported"
* remove callable interface and change interface warnings + some copies
* fix
* ternary is ugly af, make it simpler
* how did that happen
* fix flex attn test
* failing the test
* no more fallback! fixing copies next
* style + attn fixed
* fixing copies and mask creation
* wrong copy
* fixup tests and disable flex attn for now
* fixup last tests?
@@ -958,6 +958,8 @@ class ModelTesterMixin:
 
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
         config.return_dict = True
+        # force eager attention to support output attentions
+        config._attn_implementation = "eager"
 
         seq_len = getattr(self.model_tester, "seq_length", None)
         decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
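The comment "force eager attention to support output attentions" reflects that fused kernels such as PyTorch's `scaled_dot_product_attention` return only the attention output, whereas an eager implementation materializes the attention weights as an intermediate that can be returned. The following is a minimal pure-Python sketch of that idea, not the library's implementation; the function name `eager_attention` and the list-of-lists inputs are assumptions made for illustration.

```python
import math


def eager_attention(query, key, value):
    """Single-head scaled dot-product attention over lists of float vectors.

    Hypothetical illustration: returns both the output and the explicit
    attention weights, which is what "output attentions" relies on.
    """
    scale = 1.0 / math.sqrt(len(query[0]))
    weights = []
    for q in query:
        # Scaled dot product of this query with every key.
        scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in key]
        # Numerically stable softmax over the key dimension.
        max_score = max(scores)
        exps = [math.exp(s - max_score) for s in scores]
        total = sum(exps)
        weights.append([e / total for e in exps])
    # Weighted sum of the value vectors for each query position.
    output = [
        [sum(w * v[j] for w, v in zip(row, value)) for j in range(len(value[0]))]
        for row in weights
    ]
    return output, weights


output, weights = eager_attention(
    query=[[1.0, 0.0]],
    key=[[1.0, 0.0], [0.0, 1.0]],
    value=[[1.0, 2.0], [3.0, 4.0]],
)
```

Because `weights` exists as an ordinary value here, a test can inspect it; a fused kernel that never exposes the weights cannot satisfy the same check.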
@@ -1106,7 +1108,11 @@ class ModelTesterMixin:
         configs_no_init.torchscript = True
         for model_class in self.all_model_classes:
             for attn_implementation in ["eager", "sdpa"]:
-                if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
+                if (
+                    attn_implementation == "sdpa"
+                    and (not model_class._supports_sdpa or not is_torch_sdpa_available())
+                    or config.output_attentions
+                ):
                     continue
 
                 configs_no_init._attn_implementation = attn_implementation
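One detail of the new condition in this hunk is worth spelling out: Python binds `and` more tightly than `or`, so the expression groups as `(sdpa and unsupported) or config.output_attentions`. A minimal sketch, using a hypothetical `should_skip` helper with plain booleans standing in for the real model class and config attributes, shows which combinations end up skipped:

```python
def should_skip(
    attn_implementation: str,
    supports_sdpa: bool,
    sdpa_available: bool,
    output_attentions: bool,
) -> bool:
    """Mirror the shape of the condition in the hunk with plain values.

    Hypothetical stand-in: the real test reads these values from
    model_class, is_torch_sdpa_available() and config.
    """
    return (
        attn_implementation == "sdpa"
        and (not supports_sdpa or not sdpa_available)
        or output_attentions
    )


# Evaluate every combination for a quick overview.
skipped = {
    (impl, supported, out_attn): should_skip(impl, supported, True, out_attn)
    for impl in ("eager", "sdpa")
    for supported in (True, False)
    for out_attn in (True, False)
}
```

Under this grouping, `output_attentions=True` skips the iteration for both implementations, including eager. Whether that is the intended behaviour is not stated in the commit.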
@@ -1708,6 +1714,10 @@ class ModelTesterMixin:
         config.output_hidden_states = True
         config.output_attentions = self.has_attentions
 
+        # force eager attention to support output attentions
+        if self.has_attentions:
+            config._attn_implementation = "eager"
+
         # no need to test all models as different heads yield the same functionality
         model_class = self.all_model_classes[0]
         model = model_class(config)
@@ -4555,13 +4565,26 @@ class ModelTesterMixin:
             # TODO: raushan, fix for composite models after making VLMs support new attn API
             if not model_class._supports_flex_attn or self._is_composite:
                 self.skipTest(reason="This model does not support flex attention")
 
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
             config._attn_implementation = "flex_attention"
-            model = model_class(config).to(device=torch_device, dtype=torch.float16)
+            # Flex Attention can not use dropout
+            if hasattr(config, "attention_dropout"):
+                config.attention_dropout = 0
+            if hasattr(config, "attention_probs_dropout_prob"):
+                config.attention_probs_dropout_prob = 0
+
+            model = model_class(config).to(device=torch_device)
             self.assertTrue(model.config._attn_implementation == "flex_attention")
 
+            # Elaborate workaround for encoder-decoder models as some do not specify their main input
+            dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
+            if config.is_encoder_decoder:
+                dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
+                dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
+
             # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
-            _ = model(inputs_dict["input_ids"].to(torch_device))
+            _ = model(**dummy_inputs)
 
     def test_generation_tester_mixin_inheritance(self):
         """
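The flex attention hunk does two things before the forward pass: it zeroes whichever attention-dropout fields the config defines, and it builds the inputs from the model's main input name instead of assuming `input_ids`. Below is a rough sketch of that pattern with plain Python objects in place of tensors and real configs. The helper names `zero_attention_dropout` and `build_dummy_inputs`, the `SimpleNamespace` config, and the `input_features` example are assumptions for illustration; the device transfer done in the real test is omitted.

```python
from types import SimpleNamespace


def zero_attention_dropout(config):
    """Set whichever attention-dropout fields the config defines to 0."""
    for name in ("attention_dropout", "attention_probs_dropout_prob"):
        if hasattr(config, name):
            setattr(config, name, 0)
    return config


def build_dummy_inputs(main_input_name, inputs_dict, is_encoder_decoder):
    """Select the main input, plus decoder inputs for encoder-decoder models."""
    dummy_inputs = {main_input_name: inputs_dict[main_input_name]}
    if is_encoder_decoder:
        dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
        dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
    return dummy_inputs


# Hypothetical speech-style encoder-decoder whose main input is not "input_ids".
config = SimpleNamespace(attention_dropout=0.1, is_encoder_decoder=True)
zero_attention_dropout(config)

inputs_dict = {
    "input_features": [[0.5, 0.25]],
    "attention_mask": [[1, 1]],
    "decoder_input_ids": [[2, 3]],
    "decoder_attention_mask": [[1, 1]],
}
dummy_inputs = build_dummy_inputs("input_features", inputs_dict, config.is_encoder_decoder)
```

Keying on the main input name is what lets the same test body run for models that take something other than `input_ids`, which the removed `model(inputs_dict["input_ids"])` call could not do.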