use the enable_gqa param in torch.nn.functional.scaled_dot_product_at… (#39412)
* use the enable_gqa param in torch.nn.functional.scaled_dot_product_attention Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * ci failure fix Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * add check Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix ci failure Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine code, extend to cuda Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine code Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * fix review comments Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> * refine the PR Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> --------- Signed-off-by: Wang, Yi A <yi.a.wang@intel.com> Co-authored-by: Cyril Vallez <cyril.vallez@huggingface.co>
This commit is contained in:
@@ -3,11 +3,15 @@ from typing import Optional
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from ..utils.import_utils import is_torch_greater_or_equal
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_is_torch_greater_or_equal_than_2_5 = is_torch_greater_or_equal("2.5", accept_dev=True)
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||||
@@ -20,6 +24,14 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
|||||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||||
|
|
||||||
|
|
||||||
|
def use_gqa_in_sdpa(attention_mask: Optional[torch.Tensor], key: torch.Tensor) -> bool:
|
||||||
|
# GQA can only be used under the following conditions
|
||||||
|
# 1. torch version >= 2.5
|
||||||
|
# 2. attention_mask is None (otherwise it will fall back to the math kernel)
|
||||||
|
# 3. key is not a torch.fx.Proxy (otherwise it will fail with a tracing error)
|
||||||
|
return _is_torch_greater_or_equal_than_2_5 and attention_mask is None and not isinstance(key, torch.fx.Proxy)
|
||||||
|
|
||||||
|
|
||||||
def sdpa_attention_forward(
|
def sdpa_attention_forward(
|
||||||
module: torch.nn.Module,
|
module: torch.nn.Module,
|
||||||
query: torch.Tensor,
|
query: torch.Tensor,
|
||||||
@@ -36,10 +48,13 @@ def sdpa_attention_forward(
|
|||||||
"`sdpa` attention does not support `output_attentions=True` or `head_mask`."
|
"`sdpa` attention does not support `output_attentions=True` or `head_mask`."
|
||||||
" Please set your attention to `eager` if you want any of these features."
|
" Please set your attention to `eager` if you want any of these features."
|
||||||
)
|
)
|
||||||
|
sdpa_kwargs = {}
|
||||||
if hasattr(module, "num_key_value_groups"):
|
if hasattr(module, "num_key_value_groups"):
|
||||||
key = repeat_kv(key, module.num_key_value_groups)
|
if not use_gqa_in_sdpa(attention_mask, key):
|
||||||
value = repeat_kv(value, module.num_key_value_groups)
|
key = repeat_kv(key, module.num_key_value_groups)
|
||||||
|
value = repeat_kv(value, module.num_key_value_groups)
|
||||||
|
else:
|
||||||
|
sdpa_kwargs = {"enable_gqa": True}
|
||||||
|
|
||||||
if attention_mask is not None and attention_mask.ndim == 4:
|
if attention_mask is not None and attention_mask.ndim == 4:
|
||||||
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
attention_mask = attention_mask[:, :, :, : key.shape[-2]]
|
||||||
@@ -71,6 +86,7 @@ def sdpa_attention_forward(
|
|||||||
dropout_p=dropout,
|
dropout_p=dropout,
|
||||||
scale=scaling,
|
scale=scaling,
|
||||||
is_causal=is_causal,
|
is_causal=is_causal,
|
||||||
|
**sdpa_kwargs,
|
||||||
)
|
)
|
||||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||||
|
|
||||||
|
|||||||
@@ -588,12 +588,12 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
|||||||
past_key_values=past_key_values_eager,
|
past_key_values=past_key_values_eager,
|
||||||
use_cache=True,
|
use_cache=True,
|
||||||
)
|
)
|
||||||
self.assertTrue(torch.allclose(res.logits, res_eager.logits))
|
self.assertTrue(torch.allclose(res.logits, res_eager.logits, atol=1e-5))
|
||||||
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
for k1, k2 in zip(res.past_key_values.key_cache, res_eager.past_key_values.key_cache):
|
||||||
self.assertTrue(torch.allclose(k1, k2))
|
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||||
|
|
||||||
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
for v1, v2 in zip(res.past_key_values.value_cache, res_eager.past_key_values.value_cache):
|
||||||
self.assertTrue(torch.allclose(v1, v2))
|
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||||
|
|
||||||
def test_dynamic_cache_exportability_multiple_run(self):
|
def test_dynamic_cache_exportability_multiple_run(self):
|
||||||
# When exporting with DynamicCache, you should export two graphs:
|
# When exporting with DynamicCache, you should export two graphs:
|
||||||
@@ -686,10 +686,10 @@ class CacheExportIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
for k1, k2 in zip(res_export_2.past_key_values.key_cache, res_eager_2.past_key_values.key_cache):
|
||||||
self.assertTrue(torch.allclose(k1, k2))
|
self.assertTrue(torch.allclose(k1, k2, atol=1e-5))
|
||||||
|
|
||||||
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
for v1, v2 in zip(res_export_2.past_key_values.value_cache, res_eager_2.past_key_values.value_cache):
|
||||||
self.assertTrue(torch.allclose(v1, v2))
|
self.assertTrue(torch.allclose(v1, v2, atol=1e-5))
|
||||||
|
|
||||||
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
|
@unittest.skip("Runs on my machine locally, passed, no idea why it does not online")
|
||||||
def test_static_cache_exportability(self):
|
def test_static_cache_exportability(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user