Enable miscellaneous test cases on XPU and use device-agnostic APIs in tests (#38192)

* use device-agnostic APIs in tests

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* more

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* add reset_peak_memory_stats API

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* update

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yao Matrix
2025-05-20 16:09:01 +08:00
committed by GitHub
parent dbc4b91db4
commit 3bd1c20149
13 changed files with 52 additions and 30 deletions

View File

@@ -21,6 +21,7 @@ from transformers import set_seed
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
from transformers.testing_utils import (
CaptureStderr,
backend_device_count,
cleanup,
get_gpu_count,
is_torch_available,
@@ -210,8 +211,8 @@ class CacheIntegrationTest(unittest.TestCase):
if not has_accelerator:
self.skipTest("Offloaded caches require an accelerator")
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
if torch.cuda.device_count() != 1:
self.skipTest("Offloaded static caches require exactly 1 GPU")
if backend_device_count(torch_device) != 1:
self.skipTest("Offloaded static caches require exactly 1 accelerator")
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
def test_cache_batched(self, cache_implementation):