From 99f9f1042f59f60de9a8f0538c1117e4eca38ef9 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 7 Apr 2025 20:50:48 +0800 Subject: [PATCH] Fix torchao usage (#37034) * fix load path Signed-off-by: jiqing-feng * fix path Signed-off-by: jiqing-feng * Fix torchao usage Signed-off-by: jiqing-feng * fix tests Signed-off-by: jiqing-feng * fix format Signed-off-by: jiqing-feng * revert useless change Signed-off-by: jiqing-feng * format Signed-off-by: jiqing-feng * revert fp8 test Signed-off-by: jiqing-feng * fix fp8 test Signed-off-by: jiqing-feng * fix fp8 test Signed-off-by: jiqing-feng * fix torch dtype Signed-off-by: jiqing-feng --------- Signed-off-by: jiqing-feng Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/utils/quantization_config.py | 18 ++++- .../torchao_integration/test_torchao.py | 68 +++++++++---------- 2 files changed, 50 insertions(+), 36 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index b0f119c58b..3bf205e8e1 100644 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -20,7 +20,7 @@ import dataclasses import importlib.metadata import json import os -from dataclasses import dataclass +from dataclasses import dataclass, is_dataclass from enum import Enum from inspect import Parameter, signature from typing import Any, Dict, List, Optional, Tuple, Union @@ -1627,6 +1627,7 @@ class TorchAoConfig(QuantizationConfigMixin): and is_torchao_available() and self.quant_type == "int4_weight_only" and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0") + and quant_type_kwargs.get("layout", None) is None ): from torchao.dtypes import Int4CPULayout @@ -1643,7 +1644,17 @@ class TorchAoConfig(QuantizationConfigMixin): if isinstance(self.quant_type, str): # Handle layout serialization if present if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: - d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"]) + if is_dataclass(d["quant_type_kwargs"]["layout"]): + d["quant_type_kwargs"]["layout"] = [ + d["quant_type_kwargs"]["layout"].__class__.__name__, + dataclasses.asdict(d["quant_type_kwargs"]["layout"]), + ] + if isinstance(d["quant_type_kwargs"]["layout"], list): + assert len(d["quant_type_kwargs"]["layout"]) == 2, "layout saves layout name and layour kwargs" + assert isinstance(d["quant_type_kwargs"]["layout"][0], str), "layout name must be a string" + assert isinstance(d["quant_type_kwargs"]["layout"][1], dict), "layout kwargs must be a dict" + else: + raise ValueError("layout must be a list") else: # Handle AOBaseConfig serialization from torchao.core.config import config_to_dict @@ -1661,6 +1672,9 @@ class TorchAoConfig(QuantizationConfigMixin): assert ao_verison > version.parse("0.9.0"), "TorchAoConfig requires torchao > 0.9.0 for construction from dict" config_dict = config_dict.copy() quant_type = config_dict.pop("quant_type") + + if isinstance(quant_type, str): + return cls(quant_type=quant_type, **config_dict) # Check if we only have one key which is "default" # In the future we may update this assert len(quant_type) == 1 and "default" in quant_type, ( diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index b6c12ab738..5f4a21e073 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -104,8 +104,8 @@ class TorchAoConfigTest(unittest.TestCase): """ quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout()) d = quantization_config.to_dict() - self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict) - self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"]) + self.assertIsInstance(d["quant_type_kwargs"]["layout"], list) + self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"][1]) quantization_config.to_json_string(use_diff=False) @@ -159,7 +159,7 @@ class TorchAoTest(unittest.TestCase): # Note: we quantize the bfloat16 model on the fly to int4 quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=None, + torch_dtype=torch.bfloat16, device_map=self.device, quantization_config=quant_config, ) @@ -282,7 +282,7 @@ class TorchAoGPUTest(TorchAoTest): quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=torch.bfloat16, + torch_dtype="auto", device_map=self.device, quantization_config=quant_config, ) @@ -295,7 +295,7 @@ class TorchAoGPUTest(TorchAoTest): check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj) - EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready' + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)" output = quantized_model.generate( **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static" ) @@ -307,9 +307,7 @@ class TorchAoGPUTest(TorchAoTest): class TorchAoSerializationTest(unittest.TestCase): input_text = "What are we having for dinner?" max_new_tokens = 10 - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" - # TODO: investigate why we don't have the same output as the original model for this test - SERIALIZED_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" + EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside" model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" quant_scheme = "int4_weight_only" quant_scheme_kwargs = ( @@ -326,9 +324,10 @@ class TorchAoSerializationTest(unittest.TestCase): def setUp(self): self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs) + torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto" self.quantized_model = AutoModelForCausalLM.from_pretrained( self.model_name, - torch_dtype=torch.bfloat16, + torch_dtype=torch_dtype, device_map=self.device, quantization_config=self.quant_config, ) @@ -342,16 +341,17 @@ class TorchAoSerializationTest(unittest.TestCase): input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(self.device) output = self.quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens) - self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.ORIGINAL_EXPECTED_OUTPUT) + self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT) def check_serialization_expected_output(self, device, expected_output): """ Test if we can serialize and load/infer the model again on the same device """ + torch_dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto" with tempfile.TemporaryDirectory() as tmpdirname: self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False) loaded_quantized_model = AutoModelForCausalLM.from_pretrained( - self.model_name, torch_dtype=torch.bfloat16, device_map=device + tmpdirname, torch_dtype=torch_dtype, device_map=device ) input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device) @@ -359,33 +359,31 @@ class TorchAoSerializationTest(unittest.TestCase): self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output) def test_serialization_expected_output(self): - self.check_serialization_expected_output(self.device, self.SERIALIZED_EXPECTED_OUTPUT) + self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT) class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" @require_torch_gpu def test_serialization_expected_output_on_cuda(self): """ Test if we can serialize on device (cpu) and load/infer the model on cuda """ - self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT) + self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT) class TorchAoSerializationW8CPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" @require_torch_gpu def test_serialization_expected_output_on_cuda(self): """ Test if we can serialize on device (cpu) and load/infer the model on cuda """ - self.check_serialization_expected_output("cuda", self.SERIALIZED_EXPECTED_OUTPUT) + self.check_serialization_expected_output("cuda", self.EXPECTED_OUTPUT) @require_torch_gpu @@ -397,53 +395,55 @@ class TorchAoSerializationGPTTest(TorchAoSerializationTest): @require_torch_gpu class TorchAoSerializationW8A8GPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" device = "cuda:0" @require_torch_gpu class TorchAoSerializationW8GPUTest(TorchAoSerializationTest): quant_scheme, quant_scheme_kwargs = "int8_weight_only", {} - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" device = "cuda:0" @require_torch_gpu @require_torchao_version_greater_or_equal("0.10.0") class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest): - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" device = "cuda:0" - def setUp(self): + # called only once for all test in this class + @classmethod + def setUpClass(cls): if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") from torchao.quantization import Float8WeightOnlyConfig - self.quant_scheme = Float8WeightOnlyConfig() - self.quant_scheme_kwargs = {} - super().setUp() + cls.quant_scheme = Float8WeightOnlyConfig() + cls.quant_scheme_kwargs = {} + + super().setUpClass() @require_torch_gpu @require_torchao_version_greater_or_equal("0.10.0") class TorchAoSerializationA8W4Test(TorchAoSerializationTest): - ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" - SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT + EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)" device = "cuda:0" - def setUp(self): + # called only once for all test in this class + @classmethod + def setUpClass(cls): if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9: raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests") from torchao.quantization import Int8DynamicActivationInt4WeightConfig - self.quant_scheme = Int8DynamicActivationInt4WeightConfig() - self.quant_scheme_kwargs = {} - super().setUp() + cls.quant_scheme = Int8DynamicActivationInt4WeightConfig() + cls.quant_scheme_kwargs = {} + + super().setUpClass() if __name__ == "__main__":