From 475664e2c6b986a7ccb268f35ac4a2f2e5654568 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Mon, 31 Mar 2025 17:53:02 +0800 Subject: [PATCH] [tests] remove cuda-only test marker in `AwqConfigTest` (#37032) * enable on xpu * add xpu support --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/utils/quantization_config.py | 13 +++++++------ tests/quantization/autoawq/test_awq.py | 5 +++-- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 7128677a79..8b29cfec1a 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -893,12 +893,13 @@ class AwqConfig(QuantizationConfigMixin): if self.backend == AwqBackendPackingMethod.LLMAWQ: # Only cuda device can run this function - if not torch.cuda.is_available(): - raise ValueError("LLM-AWQ backend is only supported on CUDA") - compute_capability = torch.cuda.get_device_capability() - major, minor = compute_capability - if major < 8: - raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0") + if not (torch.cuda.is_available() or torch.xpu.is_available()): + raise ValueError("LLM-AWQ backend is only supported on CUDA and XPU") + if torch.cuda.is_available(): + compute_capability = torch.cuda.get_device_capability() + major, minor = compute_capability + if major < 8: + raise ValueError("LLM-AWQ backend is only supported on CUDA GPUs with compute capability >= 8.0") if self.do_fuse and self.fuse_max_seq_len is None: raise ValueError( diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index d597f8de71..913c6636b1 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -41,7 +41,6 @@ if is_accelerate_available(): @require_torch_accelerator class AwqConfigTest(unittest.TestCase): - @require_torch_gpu def test_wrong_backend(self): """ Simple test that checks if a user passes a wrong backend an error is raised @@ -59,13 +58,15 @@ class AwqConfigTest(unittest.TestCase): with self.assertRaises(ValueError): AwqConfig(bits=4, backend="unexisting-backend") - # Only cuda device can run this function + # Only cuda and xpu devices can run this function support_llm_awq = False if torch.cuda.is_available(): compute_capability = torch.cuda.get_device_capability() major, minor = compute_capability if major >= 8: support_llm_awq = True + elif torch.xpu.is_available(): + support_llm_awq = True if support_llm_awq: # LLMAWQ should work on an A100