Fix test_eager_matches_sdpa_inference for XPU backend (#34889)

* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* Fix test_eager_matches_sdpa_inference for XPU backend

As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
which is implemented on PyTorch level using aten operators and is device
agnostic with respect to implementation of each aten operator. Thus, we can
reuse CUDA (or CPU) MATH weights for XPU.

Fixes: #34888
Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

* Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>

---------

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
Dmitry Rogozhkin
2024-12-02 07:21:04 -08:00
committed by GitHub
parent f41d5d8f74
commit 31830474bf
5 changed files with 64 additions and 12 deletions

View File

@@ -41,7 +41,7 @@ from transformers.utils import (
)
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
if is_torch_available():
@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
# TODO: test gradients as well (& for FA2 as well!)
with torch.no_grad():
with torch.backends.cuda.sdp_kernel(
with sdpa_kernel(
enable_flash=enable_kernels,
enable_math=True,
enable_mem_efficient=enable_kernels,
@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
if torch_device in ["cpu", "cuda"]:
atol = atols[torch_device, enable_kernels, torch_dtype]
rtol = rtols[torch_device, enable_kernels, torch_dtype]
elif torch_device == "xpu":
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
# which is implemented on PyTorch level using aten operators and is
# device agnostic with respect to implementation of each aten operator.
atol = atols["cuda", False, torch_dtype]
rtol = rtols["cuda", False, torch_dtype]
else:
atol = 1e-7
rtol = 1e-4