use the enable_gqa param in torch.nn.functional.scaled_dot_product_at… (#39412)
* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * ci failure fix Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add check Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix ci failure Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine code, extend to cuda Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine code Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix review comments Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine the PR Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -588,12 +588,12 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
past_key_values=past_key_values_eager,
|
||||
use_cache=True,
|
||||
)
|
||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
|
||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
|
||||
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2))
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
|
||||
def test_dynamic_cache_exportability_multiple_run(self):
|
||||
# When exporting with DynamicCache, you should export two graphs:
|
||||
@@ -686,10 +686,10 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
||||
self.assertTrue(torch.allclose(k1, k2))
|
||||
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||
|
||||
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
||||
self.assertTrue(torch.allclose(v1, v2))
|
||||
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||
|
||||
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
|
||||
def test_static_cache_exportability(self):
|
||||
|
||||
Reference in New Issue
Block a user