[tests] remove cuda-only test marker in AwqConfigTest (#37032)
* enable on xpu * add xpu support --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -893,12 +893,13 @@ class AwqConfig(QuantizationConfigMixin):
|
|||||||
|
|
||||||
if self.backend == AwqBackendPackingMethod.LLMAWQ:
|
if self.backend == AwqBackendPackingMethod.LLMAWQ:
|
||||||
# Only cuda device can run this function
|
# Only cuda device can run this function
|
||||||
if not torch.cuda.is_available():
|
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
||||||
raise ValueError("LLM-AWQ backend is only supported on CUDA")
|
raise ValueError("LLM-AWQ backend is only supported on CUDA and XPU")
|
||||||
|
if torch.cuda.is_available():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
major, minor = compute_capability
|
major, minor = compute_capability
|
||||||
if major < 8:
|
if major < 8:
|
||||||
raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")
|
raise ValueError("LLM-AWQ backend is only supported on CUDA GPUs with compute capability >= 8.0")
|
||||||
|
|
||||||
if self.do_fuse and self.fuse_max_seq_len is None:
|
if self.do_fuse and self.fuse_max_seq_len is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ if is_accelerate_available():
|
|||||||
|
|
||||||
@require_torch_accelerator
|
@require_torch_accelerator
|
||||||
class AwqConfigTest(unittest.TestCase):
|
class AwqConfigTest(unittest.TestCase):
|
||||||
@require_torch_gpu
|
|
||||||
def test_wrong_backend(self):
|
def test_wrong_backend(self):
|
||||||
"""
|
"""
|
||||||
Simple test that checks if a user passes a wrong backend an error is raised
|
Simple test that checks if a user passes a wrong backend an error is raised
|
||||||
@@ -59,13 +58,15 @@ class AwqConfigTest(unittest.TestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
AwqConfig(bits=4, backend="unexisting-backend")
|
AwqConfig(bits=4, backend="unexisting-backend")
|
||||||
|
|
||||||
# Only cuda device can run this function
|
# Only cuda and xpu devices can run this function
|
||||||
support_llm_awq = False
|
support_llm_awq = False
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
compute_capability = torch.cuda.get_device_capability()
|
compute_capability = torch.cuda.get_device_capability()
|
||||||
major, minor = compute_capability
|
major, minor = compute_capability
|
||||||
if major >= 8:
|
if major >= 8:
|
||||||
support_llm_awq = True
|
support_llm_awq = True
|
||||||
|
elif torch.xpu.is_available():
|
||||||
|
support_llm_awq = True
|
||||||
|
|
||||||
if support_llm_awq:
|
if support_llm_awq:
|
||||||
# LLMAWQ should work on an A100
|
# LLMAWQ should work on an A100
|
||||||
|
|||||||
Reference in New Issue
Block a user