use the enable_gqa param in torch.nn.functional.scaled_dot_product_at… (#39412)

* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ci failure fix

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add check

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code, extend to cuda

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix review comments

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Wang, Yi
2025-07-21 20:46:43 +08:00
committed by GitHub
parent 6b3a1f2f51
commit 9323d0873c
2 changed files with 24 additions and 8 deletions

View File

@@ -588,12 +588,12 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=past_key_values_eager,
use_cache=True,
)
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
def test_dynamic_cache_exportability_multiple_run(self):
# When exporting with DynamicCache, you should export two graphs:
@@ -686,10 +686,10 @@ class CacheExportIntegrationTest(unittest.TestCase):
)
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2))
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2))
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
def test_static_cache_exportability(self):