🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa) * flash attention works, remove unnecessary code * flex attention support for bart!, gotta check if the renaming is not too aggressive * some comments * skip flex grad test for standalone as done with the other test * revert flex attn rename (for now), sdpa simplify, and todos * more todos * refactor mask creation for reuse * modular attempt at biogpt * first batch of other models * fix attn dropout * fix autoformer copies * hubert * another batch of models * copies/style + last round of bart models --> whisper next? * remove unnecessary _reshape function and remove copy to whisper * add skip for decoder-only models out of enc-dec (same as in bart) * bring back licences * remove comment, added to pr read instead * mostly docs * disable sew flex attn as it's unclear attn mask for now * oops * test fixes for enc-dec * torch fx fixes + try at flex attn * skip on mbart * some more fixes * musicgen skip / delete old attn class logic + sdpa compose compile skip * disable flex attn for musicgen, not worth the effort * more fixes and style * flex attention test for dropout and encoder decoder that dont have main input names * informer fixes * the weirdest thing I've encountered yet... * style * remove empty tensor attempt, found core root in previous commits * disable time series due to tests being very text centric on inputs * add speech to text to be ignoring the other attns, also due to tests * update docs * remaining issues resolved ? * update docs for current state --> nllb moe and pegasus x sdpa is questionable :D * some models have not set the is_causal flag... * change dtype in softmax tol old behaviour + some modular fixes * I hate it but it is what it is * fixes from main for bart * forgot this one * some model fixes * style * current status * marian works now * fixing some copies * some copy fixes + time series x informer * last models possibly and fixes on style/copies * some post merge fixes * more fixes * make attention interface callable and move warnings there * style lol * add comment to "unsupported" * remove callable interface and change interface warnings + some copies * fix * ternary is ugly af, make it simpler * how did that happen * fix flex attn test * failing the test * no more fallback! fixing copies next * style + attn fixed * fixing copies and mask creation * wrong copy * fixup tests and disable flex attn for now * fixup last tests?
This commit is contained in:
@@ -1148,6 +1148,10 @@ class GenerationTesterMixin:
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# force eager attention to support output attentions
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
@@ -1228,6 +1232,10 @@ class GenerationTesterMixin:
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# force eager attention to support output attentions
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
@@ -1282,6 +1290,10 @@ class GenerationTesterMixin:
|
||||
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
|
||||
|
||||
# force eager attention to support output attentions
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# Encoder-decoder models are not supported
|
||||
if config.is_encoder_decoder:
|
||||
self.skipTest("DoLa is not supported for encoder-decoder models")
|
||||
@@ -1346,6 +1358,10 @@ class GenerationTesterMixin:
|
||||
# enable cache
|
||||
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
|
||||
|
||||
# force eager attention to support output attentions
|
||||
if self.has_attentions:
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# NOTE: assisted generation only works with cache on at the moment.
|
||||
if not hasattr(config.get_text_config(), "use_cache"):
|
||||
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
|
||||
|
||||
Reference in New Issue
Block a user