From 3cb8676a915c6fa8ad863afd8a2b1a6f4507f3ec Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 20 Nov 2024 20:28:51 +0100 Subject: [PATCH] Fix CI by tweaking torchao tests (#34832) --- src/transformers/utils/quantization_config.py | 9 +++++++-- .../quantization/torchao_integration/test_torchao.py | 11 ++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index 2f04df97e8..ac81864e50 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -1264,8 +1264,13 @@ class TorchAoConfig(QuantizationConfigMixin): r""" Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. """ - if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): - raise ValueError("Requires torchao 0.4.0 version and above") + if is_torchao_available(): + if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"): + raise ValueError("Requires torchao 0.4.0 version and above") + else: + raise ValueError( + "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`" + ) _STR_TO_METHOD = self._get_torchao_quant_type_to_method() if self.quant_type not in _STR_TO_METHOD.keys(): diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index 3733d6dcf4..d0263f45f1 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -246,12 +246,13 @@ class TorchAoSerializationTest(unittest.TestCase): # TODO: investigate why we don't have the same output as the original model for this test SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" - quant_config = TorchAoConfig("int4_weight_only", group_size=32) + quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32} device = "cuda:0" # called only once for all test in this class @classmethod def setUpClass(cls): + cls.quant_config = TorchAoConfig(cls.quant_scheme, **cls.quant_scheme_kwargs) cls.quantized_model = AutoModelForCausalLM.from_pretrained( cls.model_name, torch_dtype=torch.bfloat16, @@ -290,21 +291,21 @@ class TorchAoSerializationTest(unittest.TestCase): class TorchAoSerializationW8A8Test(TorchAoSerializationTest): - quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") + quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT device = "cuda:0" class TorchAoSerializationW8Test(TorchAoSerializationTest): - quant_config = TorchAoConfig("int8_weight_only") + quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT device = "cuda:0" class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): - quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight") + quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT device = "cpu" @@ -318,7 +319,7 @@ class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): - quant_config = TorchAoConfig("int8_weight_only") + quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT device = "cpu"