use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention (#39412)

* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* ci failure fix

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* add check

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix ci failure

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code, extend to cuda

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine code

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* fix review comments

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* refine the PR

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

---------

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
Wang, Yi
2025-07-21 20:46:43 +08:00
committed by GitHub
parent 6b3a1f2f51
commit 9323d0873c
2 changed files with 24 additions and 8 deletions

View File

@@ -3,11 +3,15 @@ from typing import Optional
import torch import torch
from ..utils import logging from ..utils import logging
from ..utils.import_utils import is_torch_greater_or_equal
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
""" """
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
@@ -20,6 +24,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
    """
    Tell whether SDPA's `enable_gqa` option can be used for this call.

    Every one of the following must hold for the answer to be `True`:
    - the installed torch version is >= 2.5 (module-level flag `_is_torch_greater_or_equal_than_2_5`);
    - `attention_mask` is `None`;
    - `key` is not a `torch.fx.Proxy`.

    Args:
        attention_mask (`torch.Tensor`, *optional*):
            The attention mask that would be passed to SDPA.
        key (`torch.Tensor`):
            The key states; only inspected to detect symbolic tracing.

    Returns:
        `bool`: `True` when all the conditions above are met, `False` otherwise.
    """
    # Condition 1: torch version >= 2.5
    if not _is_torch_greater_or_equal_than_2_5:
        return False
    # Condition 2: with an attention mask, SDPA would fall back to the math kernel
    if attention_mask is not None:
        return False
    # Condition 3: a torch.fx.Proxy key would fail with a tracing error
    return not isinstance(key, torch.fx.Proxy)
def sdpa_attention_forward( def sdpa_attention_forward(
module: torch.nn.Module, module: torch.nn.Module,
query: torch.Tensor, query: torch.Tensor,
@@ -36,10 +48,13 @@ def sdpa_attention_forward(
"`sdpa` attention does not support `output_attentions=True` or `head_mask`." "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
" Please set your attention to `eager` if you want any of these features." " Please set your attention to `eager` if you want any of these features."
) )
sdpa_kwargs = {}
if hasattr(module, "num_key_value_groups"): if hasattr(module, "num_key_value_groups"):
if not use_gqa_in_sdpa(attention_mask, key):
key = repeat_kv(key, module.num_key_value_groups) key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups)
else:
sdpa_kwargs = {"enable_gqa": True}
if attention_mask is not None and attention_mask.ndim == 4: if attention_mask is not None and attention_mask.ndim == 4:
attention_mask = attention_mask[:, :, :, : key.shape[-2]] attention_mask = attention_mask[:, :, :, : key.shape[-2]]
@@ -71,6 +86,7 @@ def sdpa_attention_forward(
dropout_p=dropout, dropout_p=dropout,
scale=scaling, scale=scaling,
is_causal=is_causal, is_causal=is_causal,
**sdpa_kwargs,
) )
attn_output = attn_output.transpose(1, 2).contiguous() attn_output = attn_output.transpose(1, 2).contiguous()

View File

@@ -588,12 +588,12 @@ class CacheExportIntegrationTest(unittest.TestCase):
past_key_values=past_key_values_eager, past_key_values=past_key_values_eager,
use_cache=True, use_cache=True,
) )
self.assertTrue(torch.allclose(res.logits, res_eager.logits)) self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache): for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2)) self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache): for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2)) self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
def test_dynamic_cache_exportability_multiple_run(self): def test_dynamic_cache_exportability_multiple_run(self):
# When exporting with DynamicCache, you should export two graphs: # When exporting with DynamicCache, you should export two graphs:
@@ -686,10 +686,10 @@ class CacheExportIntegrationTest(unittest.TestCase):
) )
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache): for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
self.assertTrue(torch.allclose(k1, k2)) self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache): for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
self.assertTrue(torch.allclose(v1, v2)) self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online") @unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
def test_static_cache_exportability(self): def test_static_cache_exportability(self):