Fix test_eager_matches_sdpa_inference for XPU backend (#34889)
* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> * Fix test_eager_matches_sdpa_inference for XPU backend As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH which is implemented on PyTorch level using aten operators and is device agnostic with respect to implementation of each aten operator. Thus, we can reuse CUDA (or CPU) MATH weights for XPU. Fixes: #34888 Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> * Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> --------- Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
@@ -187,6 +187,22 @@ def _deepspeed_zero3(ds_config):
|
||||
unset_hf_deepspeed_config()
|
||||
|
||||
|
||||
def sdpa_kernel(enable_flash, enable_math, enable_mem_efficient):
|
||||
if version.parse(torch.__version__).release < version.parse("2.3").release:
|
||||
return torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=enable_flash, enable_math=enable_math, enable_mem_efficient=enable_mem_efficient
|
||||
)
|
||||
|
||||
backends = []
|
||||
if enable_flash:
|
||||
backends += [torch.nn.attention.SDPBackend.FLASH_ATTENTION]
|
||||
if enable_math:
|
||||
backends += [torch.nn.attention.SDPBackend.MATH]
|
||||
if enable_mem_efficient:
|
||||
backends += [torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION]
|
||||
return torch.nn.attention.sdpa_kernel(backends)
|
||||
|
||||
|
||||
@require_torch
|
||||
class ModelTesterMixin:
|
||||
model_tester = None
|
||||
@@ -4175,7 +4191,7 @@ class ModelTesterMixin:
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
with sdpa_kernel(
|
||||
enable_flash=enable_kernels,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=enable_kernels,
|
||||
@@ -4198,6 +4214,12 @@ class ModelTesterMixin:
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
rtol = rtols[torch_device, enable_kernels, torch_dtype]
|
||||
elif torch_device == "xpu":
|
||||
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
|
||||
# which is implemented on PyTorch level using aten operators and is
|
||||
# device agnostic with respect to implementation of each aten operator.
|
||||
atol = atols["cuda", False, torch_dtype]
|
||||
rtol = rtols["cuda", False, torch_dtype]
|
||||
else:
|
||||
atol = 1e-7
|
||||
rtol = 1e-4
|
||||
|
||||
Reference in New Issue
Block a user