[Quantization] Switch to optimum-quanto (#31732)
* switch to optimum-quanto rebase squash * fix import check * again * test try-except * style
This commit is contained in:
@@ -29,7 +29,7 @@ from transformers.testing_utils import (
|
||||
is_flaky,
|
||||
require_accelerate,
|
||||
require_auto_gptq,
|
||||
require_quanto,
|
||||
require_optimum_quanto,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_accelerator,
|
||||
@@ -1941,7 +1941,7 @@ class GenerationTesterMixin:
|
||||
self.assertTrue(len(results.past_key_values.key_cache) == num_hidden_layers)
|
||||
self.assertTrue(results.past_key_values.key_cache[0].shape == cache_shape)
|
||||
|
||||
@require_quanto
|
||||
@require_optimum_quanto
|
||||
@pytest.mark.generate
|
||||
def test_generate_with_quant_cache(self):
|
||||
for model_class in self.all_generative_model_classes:
|
||||
|
||||
@@ -19,13 +19,13 @@ import unittest
|
||||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, QuantoConfig
|
||||
from transformers.testing_utils import (
|
||||
require_accelerate,
|
||||
require_quanto,
|
||||
require_optimum_quanto,
|
||||
require_read_token,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import is_accelerate_available, is_quanto_available, is_torch_available
|
||||
from transformers.utils import is_accelerate_available, is_optimum_quanto_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
@@ -36,8 +36,8 @@ if is_torch_available():
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
|
||||
if is_quanto_available():
|
||||
from quanto import QLayerNorm, QLinear
|
||||
if is_optimum_quanto_available():
|
||||
from optimum.quanto import QLayerNorm, QLinear
|
||||
|
||||
from transformers.integrations.quanto import replace_with_quanto_layers
|
||||
|
||||
@@ -47,7 +47,7 @@ class QuantoConfigTest(unittest.TestCase):
|
||||
pass
|
||||
|
||||
|
||||
@require_quanto
|
||||
@require_optimum_quanto
|
||||
@require_accelerate
|
||||
class QuantoTestIntegration(unittest.TestCase):
|
||||
model_id = "facebook/opt-350m"
|
||||
@@ -124,7 +124,7 @@ class QuantoTestIntegration(unittest.TestCase):
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
@require_quanto
|
||||
@require_optimum_quanto
|
||||
@require_accelerate
|
||||
class QuantoQuantizationTest(unittest.TestCase):
|
||||
"""
|
||||
@@ -187,7 +187,7 @@ class QuantoQuantizationTest(unittest.TestCase):
|
||||
self.check_inference_correctness(self.quantized_model, "cuda")
|
||||
|
||||
def test_quantized_model_layers(self):
|
||||
from quanto import QBitsTensor, QModuleMixin, QTensor
|
||||
from optimum.quanto import QBitsTensor, QModuleMixin, QTensor
|
||||
|
||||
"""
|
||||
Suite of simple test to check if the layers are quantized and are working properly
|
||||
@@ -256,7 +256,7 @@ class QuantoQuantizationTest(unittest.TestCase):
|
||||
self.assertTrue(torch.equal(d0[k], d1[k].to(d0[k].device)))
|
||||
|
||||
def test_compare_with_quanto(self):
|
||||
from quanto import freeze, qint4, qint8, quantize
|
||||
from optimum.quanto import freeze, qint4, qint8, quantize
|
||||
|
||||
w_mapping = {"int8": qint8, "int4": qint4}
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
@@ -272,7 +272,7 @@ class QuantoQuantizationTest(unittest.TestCase):
|
||||
|
||||
@unittest.skip
|
||||
def test_load_from_quanto_saved(self):
|
||||
from quanto import freeze, qint4, qint8, quantize
|
||||
from optimum.quanto import freeze, qint4, qint8, quantize
|
||||
|
||||
from transformers import QuantoConfig
|
||||
|
||||
@@ -356,7 +356,7 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
|
||||
"""
|
||||
We check that we have unquantized value in the cpu and in the disk
|
||||
"""
|
||||
import quanto
|
||||
from optimum.quanto import QBitsTensor, QTensor
|
||||
|
||||
cpu_weights = self.quantized_model.transformer.h[22].self_attention.query_key_value._hf_hook.weights_map[
|
||||
"weight"
|
||||
@@ -364,13 +364,11 @@ class QuantoQuantizationOffloadTest(QuantoQuantizationTest):
|
||||
disk_weights = self.quantized_model.transformer.h[23].self_attention.query_key_value._hf_hook.weights_map[
|
||||
"weight"
|
||||
]
|
||||
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, quanto.QTensor))
|
||||
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QTensor))
|
||||
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(cpu_weights, QTensor))
|
||||
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QTensor))
|
||||
if self.weights == "int4":
|
||||
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor))
|
||||
self.assertTrue(
|
||||
isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, quanto.QBitsTensor)
|
||||
)
|
||||
self.assertTrue(isinstance(cpu_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
|
||||
self.assertTrue(isinstance(disk_weights, torch.Tensor) and not isinstance(disk_weights, QBitsTensor))
|
||||
|
||||
|
||||
@unittest.skip(reason="Skipping test class because serialization is not supported yet")
|
||||
@@ -416,18 +414,18 @@ class QuantoQuantizationSerializationCudaTest(QuantoQuantizationTest):
|
||||
|
||||
|
||||
class QuantoQuantizationQBitsTensorTest(QuantoQuantizationTest):
|
||||
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
|
||||
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
|
||||
weights = "int4"
|
||||
|
||||
|
||||
class QuantoQuantizationQBitsTensorOffloadTest(QuantoQuantizationOffloadTest):
|
||||
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
|
||||
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
|
||||
weights = "int4"
|
||||
|
||||
|
||||
@unittest.skip(reason="Skipping test class because serialization is not supported yet")
|
||||
class QuantoQuantizationQBitsTensorSerializationTest(QuantoQuantizationSerializationTest):
|
||||
EXPECTED_OUTPUTS = "Hello my name is Nils, I am a student of the University"
|
||||
EXPECTED_OUTPUTS = "Hello my name is John, I am a professional photographer, I"
|
||||
weights = "int4"
|
||||
|
||||
|
||||
@@ -443,14 +441,14 @@ class QuantoQuantizationActivationTest(unittest.TestCase):
|
||||
self.assertIn("We don't support quantizing the activations with transformers library", str(e.exception))
|
||||
|
||||
|
||||
@require_quanto
|
||||
@require_optimum_quanto
|
||||
@require_torch_gpu
|
||||
class QuantoKVCacheQuantizationTest(unittest.TestCase):
|
||||
@slow
|
||||
@require_read_token
|
||||
def test_quantized_cache(self):
|
||||
EXPECTED_TEXT_COMPLETION = [
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity",
|
||||
"Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory is the most",
|
||||
"My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p",
|
||||
]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user