From 7f28da285076c8267ff4c907e3948157f490fb69 Mon Sep 17 00:00:00 2001 From: Yao Matrix Date: Fri, 16 May 2025 19:56:43 +0800 Subject: [PATCH] clean autoawq cases on xpu (#38163) * clean autoawq cases on xpu Signed-off-by: Matrix Yao * fix style Signed-off-by: Matrix Yao --------- Signed-off-by: Matrix Yao --- tests/quantization/autoawq/test_awq.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/quantization/autoawq/test_awq.py b/tests/quantization/autoawq/test_awq.py index 055f736e12..195480be49 100644 --- a/tests/quantization/autoawq/test_awq.py +++ b/tests/quantization/autoawq/test_awq.py @@ -21,6 +21,7 @@ from transformers.testing_utils import ( backend_empty_cache, require_accelerate, require_auto_awq, + require_flash_attn, require_intel_extension_for_pytorch, require_torch_accelerator, require_torch_gpu, @@ -243,7 +244,7 @@ class AwqTest(unittest.TestCase): self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) @require_torch_multi_accelerator - def test_quantized_model_multi_gpu(self): + def test_quantized_model_multi_accelerator(self): """ Simple test that checks if the quantized model is working properly with multiple GPUs """ @@ -305,7 +306,7 @@ class AwqFusedTest(unittest.TestCase): def tearDown(self): gc.collect() - torch.cuda.empty_cache() + backend_empty_cache(torch_device) gc.collect() def _check_fused_modules(self, model): @@ -359,6 +360,8 @@ class AwqFusedTest(unittest.TestCase): torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", ) + @require_flash_attn + @require_torch_gpu def test_generation_fused(self): """ Test generation quality for fused models - single batch case @@ -382,6 +385,8 @@ class AwqFusedTest(unittest.TestCase): self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION) + @require_flash_attn + @require_torch_gpu @unittest.skipIf( torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", @@ -433,6 +438,7 @@ class AwqFusedTest(unittest.TestCase): self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT) + @require_flash_attn @require_torch_multi_gpu @unittest.skipIf( torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, @@ -473,8 +479,9 @@ class AwqFusedTest(unittest.TestCase): outputs = model.generate(**inputs, max_new_tokens=12) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL) - @unittest.skip(reason="Not enough GPU memory on CI runners") + @require_flash_attn @require_torch_multi_gpu + @unittest.skip(reason="Not enough GPU memory on CI runners") def test_generation_mixtral_fused(self): """ Text generation test for Mixtral + AWQ + fused