🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa) * flash attention works, remove unnecessary code * flex attention support for bart!, gotta check if the renaming is not too aggressive * some comments * skip flex grad test for standalone as done with the other test * revert flex attn rename (for now), sdpa simplify, and todos * more todos * refactor mask creation for reuse * modular attempt at biogpt * first batch of other models * fix attn dropout * fix autoformer copies * hubert * another batch of models * copies/style + last round of bart models --> whisper next? * remove unnecessary _reshape function and remove copy to whisper * add skip for decoder-only models out of enc-dec (same as in bart) * bring back licences * remove comment, added to pr read instead * mostly docs * disable sew flex attn as it's unclear attn mask for now * oops * test fixes for enc-dec * torch fx fixes + try at flex attn * skip on mbart * some more fixes * musicgen skip / delete old attn class logic + sdpa compose compile skip * disable flex attn for musicgen, not worth the effort * more fixes and style * flex attention test for dropout and encoder decoder that dont have main input names * informer fixes * the weirdest thing I've encountered yet... * style * remove empty tensor attempt, found core root in previous commits * disable time series due to tests being very text centric on inputs * add speech to text to be ignoring the other attns, also due to tests * update docs * remaining issues resolved ? * update docs for current state --> nllb moe and pegasus x sdpa is questionable :D * some models have not set the is_causal flag... * change dtype in softmax tol old behaviour + some modular fixes * I hate it but it is what it is * fixes from main for bart * forgot this one * some model fixes * style * current status * marian works now * fixing some copies * some copy fixes + time series x informer * last models possibly and fixes on style/copies * some post merge fixes * more fixes * make attention interface callable and move warnings there * style lol * add comment to "unsupported" * remove callable interface and change interface warnings + some copies * fix * ternary is ugly af, make it simpler * how did that happen * fix flex attn test * failing the test * no more fallback! fixing copies next * style + attn fixed * fixing copies and mask creation * wrong copy * fixup tests and disable flex attn for now * fixup last tests?
This commit is contained in:
@@ -370,6 +370,9 @@ class HubertModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
@@ -632,6 +635,9 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = True
|
||||
|
||||
# force eager attention to support output attentions
|
||||
config._attn_implementation = "eager"
|
||||
|
||||
# no need to test all models as different heads yield the same functionality
|
||||
model_class = self.all_model_classes[0]
|
||||
model = model_class(config)
|
||||
|
||||
Reference in New Issue
Block a user