diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 929bbb13a5..708a3a54e3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4255,7 +4255,7 @@ class ModelTesterMixin: if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]: inputs_dict[name] = inp.to(torch.float16) - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): + with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False): _ = model(**inputs_dict) @require_non_xpu @@ -4347,11 +4347,7 @@ class ModelTesterMixin: model_sdpa = model_sdpa.eval() with torch.no_grad(): - with torch.backends.cuda.sdp_kernel( - enable_flash=False, - enable_math=True, - enable_mem_efficient=False, - ): + with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False): res_eager = model_eager(**inputs_dict, return_dict=False)[0] res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]