enable misc cases on XPU & use device agnostic APIs for cases in tests (#38192)
* use device agnostic APIs in tests Signed-off-by: Matrix Yao <matrix.yao@intel.com> * more Signed-off-by: Matrix Yao <matrix.yao@intel.com> * fix style Signed-off-by: Matrix Yao <matrix.yao@intel.com> * add reset_peak_memory_stats API Signed-off-by: YAO Matrix <matrix.yao@intel.com> * update --------- Signed-off-by: Matrix Yao <matrix.yao@intel.com> Signed-off-by: YAO Matrix <matrix.yao@intel.com> Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -21,6 +21,7 @@ from transformers import set_seed
|
||||
from transformers.generation.configuration_utils import ALL_CACHE_IMPLEMENTATIONS
|
||||
from transformers.testing_utils import (
|
||||
CaptureStderr,
|
||||
backend_device_count,
|
||||
cleanup,
|
||||
get_gpu_count,
|
||||
is_torch_available,
|
||||
@@ -210,8 +211,8 @@ class CacheIntegrationTest(unittest.TestCase):
|
||||
if not has_accelerator:
|
||||
self.skipTest("Offloaded caches require an accelerator")
|
||||
if cache_implementation in ["offloaded_static", "offloaded_hybrid_chunked"]:
|
||||
if torch.cuda.device_count() != 1:
|
||||
self.skipTest("Offloaded static caches require exactly 1 GPU")
|
||||
if backend_device_count(torch_device) != 1:
|
||||
self.skipTest("Offloaded static caches require exactly 1 accelerator")
|
||||
|
||||
@parameterized.expand(TEST_CACHE_IMPLEMENTATIONS)
|
||||
def test_cache_batched(self, cache_implementation):
|
||||
|
||||
Reference in New Issue
Block a user