@@ -3925,7 +3925,7 @@ class ModelTesterMixin:
|
|||||||
|
|
||||||
# Only non-padding tokens are expected to match.
|
# Only non-padding tokens are expected to match.
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-3)
|
torch.allclose(res_eager[attention_mask == 1], res_sdpa[attention_mask == 1], rtol=1e-4, atol=1e-4)
|
||||||
)
|
)
|
||||||
|
|
||||||
@require_flash_attn
|
@require_flash_attn
|
||||||
|
|||||||
Reference in New Issue
Block a user