Fix test_eager_matches_sdpa_inference for XPU backend (#34889)
* Use torch.nn.attention.sdpa_kernel instead of deprecated torch.backends.cuda.sdp_kernel Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> * Fix test_eager_matches_sdpa_inference for XPU backend As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH which is implemented on PyTorch level using aten operators and is device agnostic with respect to implementation of each aten operator. Thus, we can reuse CUDA (or CPU) MATH weights for XPU. Fixes: #34888 Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> * Use torch.amp.autocast instead of deprecated torch.cuda.amp.autocast in nemotron Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com> --------- Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
This commit is contained in:
@@ -41,7 +41,7 @@ from transformers.utils import (
|
||||
)
|
||||
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor
|
||||
from ...test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, sdpa_kernel
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -636,7 +636,7 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
|
||||
# TODO: test gradients as well (& for FA2 as well!)
|
||||
with torch.no_grad():
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
with sdpa_kernel(
|
||||
enable_flash=enable_kernels,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=enable_kernels,
|
||||
@@ -653,6 +653,12 @@ class MimiModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
if torch_device in ["cpu", "cuda"]:
|
||||
atol = atols[torch_device, enable_kernels, torch_dtype]
|
||||
rtol = rtols[torch_device, enable_kernels, torch_dtype]
|
||||
elif torch_device == "xpu":
|
||||
# As of PyTorch 2.5 XPU backend supports only torch.nn.attention.SDPBackend.MATH
|
||||
# which is implemented on PyTorch level using aten operators and is
|
||||
# device agnostic with respect to implementation of each aten operator.
|
||||
atol = atols["cuda", False, torch_dtype]
|
||||
rtol = rtols["cuda", False, torch_dtype]
|
||||
else:
|
||||
atol = 1e-7
|
||||
rtol = 1e-4
|
||||
|
||||
Reference in New Issue
Block a user