From 67890de3b86c81fb4775f41b4690b2abaf2a19cf Mon Sep 17 00:00:00 2001
From: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Date: Wed, 20 Nov 2024 17:24:45 +0100
Subject: [PATCH] Torchao weights only + prequantized compability (#34355)

* weights only compability

* better tests from code review

* ping torch version

* add weights_only check
---
Reviewer note (below the `---`, so not part of the commit message): this copy
was amended during review; no hunk line count changed.
- test_torchao.py: check_serialization_expected_output now reloads the
  checkpoint it just wrote to `tmpdirname`, on the `device` it is given,
  instead of re-downloading `self.model_name` onto `self.device`. With the
  reload fixed, the int4 SERIALIZED_EXPECTED_OUTPUT is tied to
  ORIGINAL_EXPECTED_OUTPUT.
- quantizer_torchao.py: added the missing space between the two concatenated
  sentences of the offload warning in is_serializable.

 src/transformers/modeling_utils.py            |  6 +-
 .../quantizers/quantizer_torchao.py           | 19 ++++
 .../torchao_integration/test_torchao.py       | 95 +++++++++++++++++++
 3 files changed, 119 insertions(+), 1 deletion(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6f2c6c194f..f679f7a190 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -3602,7 +3602,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
 
         if hf_quantizer is not None:
             hf_quantizer.validate_environment(
-                torch_dtype=torch_dtype, from_tf=from_tf, from_flax=from_flax, device_map=device_map
+                torch_dtype=torch_dtype,
+                from_tf=from_tf,
+                from_flax=from_flax,
+                device_map=device_map,
+                weights_only=weights_only,
             )
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
             device_map = hf_quantizer.update_device_map(device_map)
diff --git a/src/transformers/quantizers/quantizer_torchao.py b/src/transformers/quantizers/quantizer_torchao.py
index 9a03eb25f4..e6c2dc1ce3 100644
--- a/src/transformers/quantizers/quantizer_torchao.py
+++ b/src/transformers/quantizers/quantizer_torchao.py
@@ -91,6 +91,15 @@ class TorchAoHfQuantizer(HfQuantizer):
                     )
                 else:
                     self.offload = True
+        if self.pre_quantized:
+            weights_only = kwargs.get("weights_only", None)
+            if weights_only:
+                torch_version = version.parse(importlib.metadata.version("torch"))
+                if torch_version < version.parse("2.5.0"):
+                    raise RuntimeError(
+                        f"In order to use torchao pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
+                        f" You can also set with `weights_only=False` in `from_pretrained` if you don't want to update torch"
+                    )
 
     def update_torch_dtype(self, torch_dtype):
         if self.quantization_config.quant_type == "int4_weight_only":
@@ -103,6 +112,10 @@ class TorchAoHfQuantizer(HfQuantizer):
                     "Setting torch_dtype to torch.bfloat16 for int4_weight_only quantization since only bfloat16 is supported right now. Please set torch_dtype=torch.bfloat16 to remove this warning."
                 )
                 torch_dtype = torch.bfloat16
+        if self.quantization_config.quant_type == "int8_dynamic_activation_int8_weight":
+            if torch_dtype is None:
+                # we need to set the torch_dtype, otherwise we have dtype mismatch when performing the quantized linear op
+                torch_dtype = torch.float32
         return torch_dtype
 
     def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
@@ -198,6 +211,12 @@ class TorchAoHfQuantizer(HfQuantizer):
         )
         if not _is_torchao_serializable:
             logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
+        if self.offload and self.quantization_config.modules_to_not_convert is None:
+            logger.warning(
+                "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them. "
+                "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
+            )
+            return False
         return _is_torchao_serializable
 
     @property
diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py
index c3ab06ee61..3733d6dcf4 100644
--- a/tests/quantization/torchao_integration/test_torchao.py
+++ b/tests/quantization/torchao_integration/test_torchao.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import gc
+import tempfile
 import unittest
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
@@ -236,5 +237,99 @@ class TorchAoTest(unittest.TestCase):
         self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
 
 
+@require_torch_gpu
+@require_torchao
+class TorchAoSerializationTest(unittest.TestCase):
+    input_text = "What are we having for dinner?"
+    max_new_tokens = 10
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+    # the reloaded checkpoint is expected to generate the same text as the model it was saved from
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+    quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+    device = "cuda:0"
+
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+            cls.model_name,
+            torch_dtype=torch.bfloat16,
+            device_map=cls.device,
+            quantization_config=cls.quant_config,
+        )
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+        gc.collect()
+
+    def test_original_model_expected_output(self):
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device)
+        output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT)
+
+    def check_serialization_expected_output(self, device, expected_output):
+        """
+        Save the quantized model, reload the saved checkpoint on `device` and compare its generation with `expected_output`
+        """
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            loaded_quantized_model = AutoModelForCausalLM.from_pretrained(
+                tmpdirname, torch_dtype=torch.bfloat16, device_map=device
+            )
+            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
+
+            output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
+            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)
+
+    def test_serialization_expected_output(self):
+        self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
+class TorchAoSerializationW8A8Test(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+
+class TorchAoSerializationW8Test(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_weight_only")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+
+class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cpu"
+
+    def test_serialization_expected_output_cuda(self):
+        """
+        Test if we can serialize on device (cpu) and load/infer the model on cuda
+        """
+        new_device = "cuda:0"
+        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
+class TorchAoSerializationW8CPUTest(TorchAoSerializationTest):
+    quant_config = TorchAoConfig("int8_weight_only")
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cpu"
+
+    def test_serialization_expected_output_cuda(self):
+        """
+        Test if we can serialize on device (cpu) and load/infer the model on cuda
+        """
+        new_device = "cuda:0"
+        self.check_serialization_expected_output(new_device, self.SERIALIZED_EXPECTED_OUTPUT)
+
+
 if __name__ == "__main__":
     unittest.main()