Fix TorchAoConfig not JSON serializable (#36206)
**Summary:** TorchAoConfig optionally contains a
`torchao.dtypes.Layout` object which is a dataclass and not
JSON serializable, and so the following fails:
```
import json
from torchao.dtypes import TensorCoreTiledLayout
from transformers import TorchAoConfig
config = TorchAoConfig("int4_weight_only", layout=TensorCoreTiledLayout())
config.to_json_string()
json.dumps(config.to_dict())
```
This also causes `quantized_model.save_pretrained(...)` to
fail because the first step of this call is to JSON serialize
the config. Fixes https://github.com/pytorch/ao/issues/1704.
**Test Plan:**
python tests/quantization/torchao_integration/test_torchao.py -k test_json_serializable
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
@@ -15,6 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import copy
|
import copy
|
||||||
|
import dataclasses
|
||||||
import importlib.metadata
|
import importlib.metadata
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
@@ -1539,6 +1540,21 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|||||||
config_dict = self.to_dict()
|
config_dict = self.to_dict()
|
||||||
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
|
||||||
|
dataclasses to simple dicts.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||||
|
"""
|
||||||
|
d = super().to_dict()
|
||||||
|
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
|
||||||
|
layout = d["quant_type_kwargs"]["layout"]
|
||||||
|
layout = dataclasses.asdict(layout)
|
||||||
|
d["quant_type_kwargs"]["layout"] = layout
|
||||||
|
return d
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class BitNetConfig(QuantizationConfigMixin):
|
class BitNetConfig(QuantizationConfigMixin):
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
if is_torchao_available():
|
if is_torchao_available():
|
||||||
from torchao.dtypes import AffineQuantizedTensor
|
from torchao.dtypes import (
|
||||||
from torchao.dtypes.affine_quantized_tensor import TensorCoreTiledLayoutType
|
AffineQuantizedTensor,
|
||||||
|
TensorCoreTiledLayout,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
|
||||||
@@ -40,7 +42,7 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
|
|||||||
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
|
||||||
test_module.assertEqual(weight.quant_min, 0)
|
test_module.assertEqual(weight.quant_min, 0)
|
||||||
test_module.assertEqual(weight.quant_max, 15)
|
test_module.assertEqual(weight.quant_max, 15)
|
||||||
test_module.assertTrue(isinstance(weight.layout_type, TensorCoreTiledLayoutType))
|
test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
|
||||||
|
|
||||||
|
|
||||||
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
def check_forward(test_module, model, batch_size=1, context_size=1024):
|
||||||
@@ -82,6 +84,16 @@ class TorchAoConfigTest(unittest.TestCase):
|
|||||||
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8)
|
||||||
repr(quantization_config)
|
repr(quantization_config)
|
||||||
|
|
||||||
|
def test_json_serializable(self):
|
||||||
|
"""
|
||||||
|
Check that the config dict can be JSON serialized.
|
||||||
|
"""
|
||||||
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=32, layout=TensorCoreTiledLayout())
|
||||||
|
d = quantization_config.to_dict()
|
||||||
|
self.assertIsInstance(d["quant_type_kwargs"]["layout"], dict)
|
||||||
|
self.assertTrue("inner_k_tiles" in d["quant_type_kwargs"]["layout"])
|
||||||
|
quantization_config.to_json_string(use_diff=False)
|
||||||
|
|
||||||
|
|
||||||
@require_torch_gpu
|
@require_torch_gpu
|
||||||
@require_torchao
|
@require_torchao
|
||||||
|
|||||||
Reference in New Issue
Block a user