Use sdpa_kernel in tests (#35472)
* update: use sdpa_kernel
* update: rerun test
This commit is contained in:
@@ -4255,7 +4255,7 @@ class ModelTesterMixin:
|
|||||||
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
|
if isinstance(inp, torch.Tensor) and inp.dtype in [torch.float32, torch.float16]:
|
||||||
inputs_dict[name] = inp.to(torch.float16)
|
inputs_dict[name] = inp.to(torch.float16)
|
||||||
|
|
||||||
with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
with sdpa_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
|
||||||
_ = model(**inputs_dict)
|
_ = model(**inputs_dict)
|
||||||
|
|
||||||
@require_non_xpu
|
@require_non_xpu
|
||||||
@@ -4347,11 +4347,7 @@ class ModelTesterMixin:
|
|||||||
model_sdpa = model_sdpa.eval()
|
model_sdpa = model_sdpa.eval()
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with sdpa_kernel(enable_flash=False, enable_math=True, enable_mem_efficient=False):
|
||||||
enable_flash=False,
|
|
||||||
enable_math=True,
|
|
||||||
enable_mem_efficient=False,
|
|
||||||
):
|
|
||||||
res_eager = model_eager(**inputs_dict, return_dict=False)[0]
|
res_eager = model_eager(**inputs_dict, return_dict=False)[0]
|
||||||
res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]
|
res_sdpa = model_sdpa(**inputs_dict, return_dict=False)[0]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user