Fix TorchAoConfig not JSON serializable (#36206)

**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:

```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig

config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())

config.to_json_string()

json.dumps(config.to_dict())
```

This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes https://github.com/pytorch/ao/issues/1704.

**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
andrewor14
2025-02-18 05:05:42 -05:00
committed by GitHub
parent 626666c444
commit fdcfdbfd22
2 changed files with 31 additions and 3 deletions

View File

@@ -15,6 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import dataclasses
import importlib.metadata
import json
import os
@@ -1539,6 +1540,21 @@ class TorchAoConfig(QuantizationConfigMixin):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
dataclasses to simple dicts.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
d = super().to_dict()
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
layout = d["quant_type_kwargs"]["layout"]
layout = dataclasses.asdict(layout)
d["quant_type_kwargs"]["layout"] = layout
return d
@dataclass
class BitNetConfig(QuantizationConfigMixin):

View File

@@ -31,8 +31,10 @@ if is_torch_available():
import torch
if is_torchao_available():
from torchao.dtypes import AffineQuantizedTensor
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
from torchao.dtypes import (
AffineQuantizedTensor,
TensorCoreTiledLayout,
)
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
@@ -40,7 +42,7 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
test_module.assertEqual(weight.quant_min, 0)
test_module.assertEqual(weight.quant_max, 15)
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
def check_forward(test_module, model, batch_size=1, context_size=1024):
@@ -82,6 +84,16 @@ class TorchAoConfigTest(unittest.TestCase):
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
repr(quantization_config)
def test_json_serializable(self):
"""
Check that the config dict can be JSON serialized.
"""
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
d = quantization_config.to_dict()
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
quantization_config.to_json_string(use_diff=False)
@require_torch_gpu
@require_torchao