[tests] make quanto tests device-agnostic (#36328)
* make device-agnostic * name change
This commit is contained in:
@@ -22,7 +22,6 @@ from transformers.testing_utils import (
|
|||||||
require_optimum_quanto,
|
require_optimum_quanto,
|
||||||
require_read_token,
|
require_read_token,
|
||||||
require_torch_accelerator,
|
require_torch_accelerator,
|
||||||
require_torch_gpu,
|
|
||||||
slow,
|
slow,
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
@@ -181,11 +180,11 @@ class QuantoQuantizationTest(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
self.check_inference_correctness(self.quantized_model, "cpu")
|
self.check_inference_correctness(self.quantized_model, "cpu")
|
||||||
|
|
||||||
def test_generate_quality_cuda(self):
|
def test_generate_quality_accelerator(self):
|
||||||
"""
|
"""
|
||||||
Simple test to check the quality of the model on cuda by comparing the generated tokens with the expected tokens
|
Simple test to check the quality of the model on accelerators by comparing the generated tokens with the expected tokens
|
||||||
"""
|
"""
|
||||||
self.check_inference_correctness(self.quantized_model, "cuda")
|
self.check_inference_correctness(self.quantized_model, torch_device)
|
||||||
|
|
||||||
def test_quantized_model_layers(self):
|
def test_quantized_model_layers(self):
|
||||||
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
|
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
|
||||||
@@ -215,7 +214,7 @@ class QuantoQuantizationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.quantized_model.to(0)
|
self.quantized_model.to(0)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, "cuda"
|
self.quantized_model.transformer.h[0].self_attention.query_key_value.weight._data.device.type, torch_device
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_serialization_bin(self):
|
def test_serialization_bin(self):
|
||||||
@@ -430,7 +429,7 @@ class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializa
|
|||||||
weights = "int4"
|
weights = "int4"
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
class QuantoQuantizationActivationTest(unittest.TestCase):
|
class QuantoQuantizationActivationTest(unittest.TestCase):
|
||||||
def test_quantize_activation(self):
|
def test_quantize_activation(self):
|
||||||
quantization_config = QuantoConfig(
|
quantization_config = QuantoConfig(
|
||||||
@@ -443,7 +442,7 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
|
|||||||
|
|
||||||
|
|
||||||
@require_optimum_quanto
|
@require_optimum_quanto
|
||||||
@require_torch_gpu
|
@require_torch_accelerator
|
||||||
class QuantoKVCacheQuantizationTest(unittest.TestCase):
|
class QuantoKVCacheQuantizationTest(unittest.TestCase):
|
||||||
@slow
|
@slow
|
||||||
@require_read_token
|
@require_read_token
|
||||||
|
|||||||
Reference in New Issue
Block a user