From 7c5bd24ffac6510eb64a0cc2599d9372de809407 Mon Sep 17 00:00:00 2001 From: Fanli Lin Date: Fri, 21 Feb 2025 21:20:40 +0800 Subject: [PATCH] [tests] make quanto tests device-agnostic (#36328) * make device-agnostic * name change --- .../quantization/quanto_integration/test_quanto.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/quantization/quanto_integration/test_quanto.py b/tests/quantization/quanto_integration/test_quanto.py index 2022c33665..45ef7616ec 100644 --- a/tests/quantization/quanto_integration/test_quanto.py +++ b/tests/quantization/quanto_integration/test_quanto.py @@ -22,7 +22,6 @@ from transformers.testing_utils import ( require_optimum_quanto, require_read_token, require_torch_accelerator, - require_torch_gpu, slow, torch_device, ) @@ -181,11 +180,11 @@ class QuantoQuantizationTest(unittest.TestCase): """ self.check_inference_correctness(self.quantized_model, "cpu") - def test_generate_quality_cuda(self): + def test_generate_quality_accelerator(self): """ - Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens + Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens """ - self.check_inference_correctness(self.quantized_model, "cuda") + self.check_inference_correctness(self.quantized_model, torch_device) def test_quantized_model_layers(self): from optimum.quanto import QBitsTensor, QModuleMixin, QTensor @@ -215,7 +214,7 @@ class QuantoQuantizationTest(unittest.TestCase): ) self.quantized_model.to(0) self.assertEqual( - self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda" + self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, torch_device ) def test_serialization_bin(self): @@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa weights = "int4" -@require_torch_gpu +@require_torch_accelerator class QuantoQuantizationActivationTest(unittest.TestCase): def test_quantize_activation(self): quantization_config = QuantoConfig( @@ -443,7 +442,7 @@ class QuantoQuantizationActivationTest(unittest.TestCase): @require_optimum_quanto -@require_torch_gpu +@require_torch_accelerator class QuantoKVCacheQuantizationTest(unittest.TestCase): @slow @require_read_token