Clean up AutoAWQ test cases on XPU (#38163)

* Clean up AutoAWQ test cases on XPU

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

* fix style

Signed-off-by: Matrix Yao <matrix.yao@intel.com>

---------

Signed-off-by: Matrix Yao <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix
2025-05-16 19:56:43 +08:00
committed by GitHub
parent 01ad9f4b49
commit 7f28da2850

View File

@@ -21,6 +21,7 @@ from transformers.testing_utils import (
backend_empty_cache, backend_empty_cache,
require_accelerate, require_accelerate,
require_auto_awq, require_auto_awq,
require_flash_attn,
require_intel_extension_for_pytorch, require_intel_extension_for_pytorch,
require_torch_accelerator, require_torch_accelerator,
require_torch_gpu, require_torch_gpu,
@@ -243,7 +244,7 @@ class AwqTest(unittest.TestCase):
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
@require_torch_multi_accelerator @require_torch_multi_accelerator
def test_quantized_model_multi_gpu(self): def test_quantized_model_multi_accelerator(self):
""" """
Simple test that checks if the quantized model is working properly with multiple GPUs Simple test that checks if the quantized model is working properly with multiple GPUs
""" """
@@ -305,7 +306,7 @@ class AwqFusedTest(unittest.TestCase):
def tearDown(self): def tearDown(self):
gc.collect() gc.collect()
torch.cuda.empty_cache() backend_empty_cache(torch_device)
gc.collect() gc.collect()
def _check_fused_modules(self, model): def _check_fused_modules(self, model):
@@ -359,6 +360,8 @@ class AwqFusedTest(unittest.TestCase):
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
) )
@require_flash_attn
@require_torch_gpu
def test_generation_fused(self): def test_generation_fused(self):
""" """
Test generation quality for fused models - single batch case Test generation quality for fused models - single batch case
@@ -382,6 +385,8 @@ class AwqFusedTest(unittest.TestCase):
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION)
@require_flash_attn
@require_torch_gpu
@unittest.skipIf( @unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
"Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0", "Skipping because RuntimeError: FlashAttention only supports Ampere GPUs or newer, so not supported on GPU with capability < 8.0",
@@ -433,6 +438,7 @@ class AwqFusedTest(unittest.TestCase):
self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT) self.assertEqual(outputs[0]["generated_text"], EXPECTED_OUTPUT)
@require_flash_attn
@require_torch_multi_gpu @require_torch_multi_gpu
@unittest.skipIf( @unittest.skipIf(
torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8, torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8,
@@ -473,8 +479,9 @@ class AwqFusedTest(unittest.TestCase):
outputs = model.generate(**inputs, max_new_tokens=12) outputs = model.generate(**inputs, max_new_tokens=12)
self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL) self.assertEqual(tokenizer.decode(outputs[0], skip_special_tokens=True), self.EXPECTED_GENERATION_CUSTOM_MODEL)
@unittest.skip(reason="Not enough GPU memory on CI runners") @require_flash_attn
@require_torch_multi_gpu @require_torch_multi_gpu
@unittest.skip(reason="Not enough GPU memory on CI runners")
def test_generation_mixtral_fused(self): def test_generation_mixtral_fused(self):
""" """
Text generation test for Mixtral + AWQ + fused Text generation test for Mixtral + AWQ + fused