From fdcfdbfd221a5b35694db6fb8620eaa729a01f57 Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Tue, 18 Feb 2025 05:05:42 -0500 Subject: [PATCH] Fix TorchAoConfig not JSON serializable (#36206) **Summary:** TorchAoConfig optionally contains a `torchao.dtypes.Layout` object which is a dataclass and not JSON serializable, and so the following fails: ``` import json from torchao.dtypes import TensorCoreTiledLayout from transformers import TorchAoConfig config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout()) config.to_json_string() json.dumps(config.to_dict()) ``` This also causes `quantized_model.save_pretrained(...)` to fail because the first step of this call is to JSON serialize the config. Fixes https://github.com/pytorch/ao/issues/1704. **Test Plan:** python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/transformers/utils/quantization_config.py | 16 ++++++++++++++++ .../torchao_integration/test_torchao.py | 18 +++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py index ec8a5ef70d..3fafca29b9 100755 --- a/src/transformers/utils/quantization_config.py +++ b/src/transformers/utils/quantization_config.py @@ -15,6 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import copy +import dataclasses import importlib.metadata import json import os @@ -1539,6 +1540,21 @@ class TorchAoConfig(QuantizationConfigMixin): config_dict = self.to_dict() return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n" + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout` + dataclasses to simple dicts. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + d = super().to_dict() + if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]: + layout = d["quant_type_kwargs"]["layout"] + layout = dataclasses.asdict(layout) + d["quant_type_kwargs"]["layout"] = layout + return d + @dataclass class BitNetConfig(QuantizationConfigMixin): diff --git a/tests/quantization/torchao_integration/test_torchao.py b/tests/quantization/torchao_integration/test_torchao.py index d0263f45f1..1708550cf0 100644 --- a/tests/quantization/torchao_integration/test_torchao.py +++ b/tests/quantization/torchao_integration/test_torchao.py @@ -31,8 +31,10 @@ if is_torch_available(): import torch if is_torchao_available(): - from torchao.dtypes import AffineQuantizedTensor - from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType + from torchao.dtypes import ( + AffineQuantizedTensor, + TensorCoreTiledLayout, + ) def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024): @@ -40,7 +42,7 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024 test_module.assertTrue(isinstance(weight, AffineQuantizedTensor)) test_module.assertEqual(weight.quant_min, 0) test_module.assertEqual(weight.quant_max, 15) - test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType)) + test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout)) def check_forward(test_module, model, batch_size=1, context_size=1024): @@ -82,6 +84,16 @@ class TorchAoConfigTest(unittest.TestCase): quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) repr(quantization_config) + def test_json_serializable(self): + """ + Check that the config dict can be JSON serialized. + """ + quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout()) + d = quantization_config.to_dict() + self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict) + self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"]) + quantization_config.to_json_string(use_diff=False) + @require_torch_gpu @require_torchao