switch to device agnostic device calling for test cases (#38247)

* use device agnostic APIs in test cases

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* add one more

Signed-off-by: YAO Matrix <matrix.yao@intel.com>

* xpu now supports integer device id, aligning to CUDA behaviors

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update to use device_properties

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* update comment

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix comments

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
Signed-off-by: YAO Matrix <matrix.yao@intel.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yao Matrix
2025-05-26 16:18:53 +08:00
committed by GitHub
parent cba279f46c
commit a5a0c7b888
39 changed files with 259 additions and 389 deletions

View File

@@ -19,6 +19,7 @@ import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM
from transformers.testing_utils import (
backend_empty_cache,
get_device_properties,
require_accelerate,
require_auto_awq,
require_flash_attn,
@@ -61,12 +62,10 @@ class AwqConfigTest(unittest.TestCase):
# Only cuda and xpu devices can run this function
support_llm_awq = False
if torch.cuda.is_available():
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if major >= 8:
support_llm_awq = True
elif torch.xpu.is_available():
device_type, major = get_device_properties()
if device_type == "cuda" and major >= 8:
support_llm_awq = True
elif device_type == "xpu":
support_llm_awq = True
if support_llm_awq:
@@ -357,7 +356,7 @@ class AwqFusedTest(unittest.TestCase):
self.assertTrue(isinstance(model.model.layers[0].block_sparse_moe.gate, torch.nn.Linear))
@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
)
@require_flash_attn
@@ -388,7 +387,7 @@ class AwqFusedTest(unittest.TestCase):
@require_flash_attn
@require_torch_gpu
@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
)
def test_generation_fused_batched(self):
@@ -441,7 +440,7 @@ class AwqFusedTest(unittest.TestCase):
@require_flash_attn
@require_torch_multi_gpu
@unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
)
def test_generation_custom_model(self):