Fix TorchAoConfig not JSON serializable (#36206)
**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:
```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig
config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())
config.to_json_string()
json.dumps(config.to_dict())
```
This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes https://github.com/pytorch/ao/issues/1704.
**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -31,8 +31,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||
from torchao.dtypes import (
|
||||
AffineQuantizedTensor,
|
||||
TensorCoreTiledLayout,
|
||||
)
|
||||
|
||||
|
||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||
@@ -40,7 +42,7 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
test_module.assertEqual(weight.quant_min, 0)
|
||||
test_module.assertEqual(weight.quant_max, 15)
|
||||
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
|
||||
|
||||
|
||||
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||
@@ -82,6 +84,16 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
||||
repr(quantization_config)
|
||||
|
||||
def test_json_serializable(self):
|
||||
"""
|
||||
Check that the config dict can be JSON serialized.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
|
||||
d = quantization_config.to_dict()
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
|
||||
quantization_config.to_json_string(use_diff=False)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
|
||||
Reference in New Issue
Block a user