🔴🔴🔴 [Attention] Refactor Attention Interface for Bart-based Models (#38108)
* starting attn refactor for encoder decoder models via bart (eager + sdpa) * flash attention works, remove unnecessary code * flex attention support for bart!, gotta check if the renaming is not too aggressive * some comments * skip flex grad test for standalone as done with the other test * revert flex attn rename (for now), sdpa simplify, and todos * more todos * refactor mask creation for reuse * modular attempt at biogpt * first batch of other models * fix attn dropout * fix autoformer copies * hubert * another batch of models * copies/style + last round of bart models --> whisper next? * remove unnecessary _reshape function and remove copy to whisper * add skip for decoder-only models out of enc-dec (same as in bart) * bring back licences * remove comment, added to pr read instead * mostly docs * disable sew flex attn as it's unclear attn mask for now * oops * test fixes for enc-dec * torch fx fixes + try at flex attn * skip on mbart * some more fixes * musicgen skip / delete old attn class logic + sdpa compose compile skip * disable flex attn for musicgen, not worth the effort * more fixes and style * flex attention test for dropout and encoder decoder that dont have main input names * informer fixes * the weirdest thing I've encountered yet... * style * remove empty tensor attempt, found core root in previous commits * disable time series due to tests being very text centric on inputs * add speech to text to be ignoring the other attns, also due to tests * update docs * remaining issues resolved ? * update docs for current state --> nllb moe and pegasus x sdpa is questionable :D * some models have not set the is_causal flag... * change dtype in softmax tol old behaviour + some modular fixes * I hate it but it is what it is * fixes from main for bart * forgot this one * some model fixes * style * current status * marian works now * fixing some copies * some copy fixes + time series x informer * last models possibly and fixes on style/copies * some post merge fixes * more fixes * make attention interface callable and move warnings there * style lol * add comment to "unsupported" * remove callable interface and change interface warnings + some copies * fix * ternary is ugly af, make it simpler * how did that happen * fix flex attn test * failing the test * no more fallback! fixing copies next * style + attn fixed * fixing copies and mask creation * wrong copy * fixup tests and disable flex attn for now * fixup last tests?
This commit is contained in:
@@ -357,7 +357,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device
|
||||
[[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]],
|
||||
device=torch_device,
|
||||
)
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
@@ -374,7 +375,8 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
self.assertEqual(output.shape, expected_shape)
|
||||
# change to expected output here
|
||||
expected_slice = torch.tensor(
|
||||
[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device
|
||||
[[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]],
|
||||
device=torch_device,
|
||||
)
|
||||
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
|
||||
|
||||
@@ -426,7 +428,7 @@ class M2M100ModelIntegrationTests(unittest.TestCase):
|
||||
Overwriting the common test as the test is flaky on tiny models
|
||||
"""
|
||||
model = M2M100ForConditionalGeneration.from_pretrained(
|
||||
"facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
|
||||
"facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16
|
||||
).to(torch_device)
|
||||
|
||||
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
|
||||
|
||||
Reference in New Issue
Block a user