Fix TorchAoConfig not JSON serializable (#36206)
**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:
```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig
config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())
config.to_json_string()
json.dumps(config.to_dict())
```
This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes https://github.com/pytorch/ao/issues/1704.
**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import copy
|
||||
import dataclasses
|
||||
import importlib.metadata
|
||||
import json
|
||||
import os
|
||||
@@ -1539,6 +1540,21 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
config_dict = self.to_dict()
|
||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
|
||||
dataclasses to simple dicts.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
d = super().to_dict()
|
||||
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||
layout = d["quant_type_kwargs"]["layout"]
|
||||
layout = dataclasses.asdict(layout)
|
||||
d["quant_type_kwargs"]["layout"] = layout
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class BitNetConfig(QuantizationConfigMixin):
|
||||
|
||||
@@ -31,8 +31,10 @@ if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_torchao_available():
|
||||
from torchao.dtypes import AffineQuantizedTensor
|
||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
||||
from torchao.dtypes import (
|
||||
AffineQuantizedTensor,
|
||||
TensorCoreTiledLayout,
|
||||
)
|
||||
|
||||
|
||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||
@@ -40,7 +42,7 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
|
||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||
test_module.assertEqual(weight.quant_min, 0)
|
||||
test_module.assertEqual(weight.quant_max, 15)
|
||||
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
||||
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
|
||||
|
||||
|
||||
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||
@@ -82,6 +84,16 @@ class TorchAoConfigTest(unittest.TestCase):
|
||||
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
||||
repr(quantization_config)
|
||||
|
||||
def test_json_serializable(self):
|
||||
"""
|
||||
Check that the config dict can be JSON serialized.
|
||||
"""
|
||||
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
|
||||
d = quantization_config.to_dict()
|
||||
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
|
||||
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
|
||||
quantization_config.to_json_string(use_diff=False)
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torchao
|
||||
|
||||
Reference in New Issue
Block a user