[tests] make test_sdpa_equivalence device-agnostic (#32520)
* fix on xpu * [run_all]
This commit is contained in:
@@ -27,6 +27,7 @@ from transformers.testing_utils import (
|
|||||||
require_flash_attn,
|
require_flash_attn,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch,
|
require_torch,
|
||||||
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_sdpa,
|
require_torch_sdpa,
|
||||||
slow,
|
slow,
|
||||||
@@ -460,7 +461,7 @@ class GemmaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixi
|
|||||||
self.skipTest(reason="Gemma flash attention does not support right padding")
|
self.skipTest(reason="Gemma flash attention does not support right padding")
|
||||||
|
|
||||||
@require_torch_sdpa
|
@require_torch_sdpa
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
@slow
|
@slow
|
||||||
def test_sdpa_equivalence(self):
|
def test_sdpa_equivalence(self):
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
|
|||||||
Reference in New Issue
Block a user