Enable non-safetensor ser/deser for TorchAoConfig quantized model 🔴 (#33456)
* Enable non-safetensor serialization and deserialization for TorchAoConfig quantized model Summary: After https://github.com/huggingface/huggingface_hub/pull/2440 we added non-safetensor serialization and deserialization in huggingface, with this we can now add the support in transformers Note that we don't plan to add safetensor serialization due to different goals of wrapper tensor subclass and safetensor see README for more details Test Plan: tested locally Reviewers: Subscribers: Tasks: Tags: * formatting * formatting * minor fix * formatting * address comments * comments * minor fix * update doc * refactor compressed tensor quantizer
This commit is contained in:
@@ -11,7 +11,7 @@ rendered properly in your Markdown viewer.
|
||||
|
||||
# TorchAO
|
||||
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes)
|
||||
[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
|
||||
|
||||
Before you begin, make sure the following libraries are installed with their latest version:
|
||||
|
||||
@@ -21,6 +21,7 @@ pip install --upgrade torch torchao
|
||||
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
|
||||
model_name = "meta-llama/Meta-Llama-3-8B"
|
||||
@@ -40,6 +41,51 @@ quantized_model = torch.compile(quantized_model, mode="max-autotune")
|
||||
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10)
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
# benchmark the performance
|
||||
import torch.utils.benchmark as benchmark
|
||||
|
||||
def benchmark_fn(f, *args, **kwargs):
|
||||
# Manual warmup
|
||||
for _ in range(5):
|
||||
f(*args, **kwargs)
|
||||
|
||||
t0 = benchmark.Timer(
|
||||
stmt="f(*args, **kwargs)",
|
||||
globals={"args": args, "kwargs": kwargs, "f": f},
|
||||
num_threads=torch.get_num_threads(),
|
||||
)
|
||||
return f"{(t0.blocked_autorange().mean):.3f}"
|
||||
|
||||
MAX_NEW_TOKENS = 1000
|
||||
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
|
||||
|
||||
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
|
||||
bf16_model = torch.compile(bf16_model, mode="max-autotune")
|
||||
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
|
||||
|
||||
```
|
||||
|
||||
torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working.
|
||||
## Serialization and Deserialization
|
||||
torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), it only work with huggingface non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution during load time and use [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions.
|
||||
|
||||
The reason why it does not support safe tensor serialization is that wrapper tensor subclass allows maximum flexibility so we want to make sure the effort of supporting new format of quantized Tensor is low, while safe tensor optimizes for maximum safety (no user code execution), it also means we have to make sure to manually support new quantization format.
|
||||
|
||||
```py
|
||||
# save quantized model locally
|
||||
output_dir = "llama3-8b-int4wo-128"
|
||||
quantized_model.save_pretrained(output_dir, safe_serialization=False)
|
||||
|
||||
# push to huggingface hub
|
||||
# save_to = "{user_id}/llama3-8b-int4wo-128"
|
||||
# quantized_model.push_to_hub(save_to, safe_serialization=False)
|
||||
|
||||
# load quantized model
|
||||
ckpt_id = "llama3-8b-int4wo-128" # or huggingface hub model id
|
||||
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
|
||||
|
||||
|
||||
# confirm the speedup
|
||||
loaded_quantized_model = torch.compile(loaded_quantized_model, mode="max-autotune")
|
||||
print("loaded int4wo-128 model:", benchmark_fn(loaded_quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
|
||||
```
|
||||
|
||||
@@ -540,7 +540,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
|
||||
return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
|
||||
def load_state_dict(
|
||||
checkpoint_file: Union[str, os.PathLike],
|
||||
is_quantized: bool = False,
|
||||
map_location: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
@@ -555,13 +559,18 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool
|
||||
)
|
||||
return safe_load_file(checkpoint_file)
|
||||
try:
|
||||
if (
|
||||
(is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
|
||||
or (is_fsdp_enabled() and not is_local_dist_rank_0())
|
||||
) and not is_quantized:
|
||||
map_location = "meta"
|
||||
else:
|
||||
map_location = "cpu"
|
||||
if map_location is None:
|
||||
if (
|
||||
(
|
||||
is_deepspeed_zero3_enabled()
|
||||
and torch.distributed.is_initialized()
|
||||
and torch.distributed.get_rank() > 0
|
||||
)
|
||||
or (is_fsdp_enabled() and not is_local_dist_rank_0())
|
||||
) and not is_quantized:
|
||||
map_location = "meta"
|
||||
else:
|
||||
map_location = "cpu"
|
||||
extra_args = {}
|
||||
# mmap can only be used with files serialized with zipfile-based format.
|
||||
if (
|
||||
@@ -2564,7 +2573,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
|
||||
hf_quantizer = getattr(self, "hf_quantizer", None)
|
||||
quantization_serializable = (
|
||||
hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable
|
||||
hf_quantizer is not None
|
||||
and isinstance(hf_quantizer, HfQuantizer)
|
||||
and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
|
||||
)
|
||||
|
||||
if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
|
||||
@@ -4479,7 +4490,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
# Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
|
||||
if shard_file in disk_only_shard_files:
|
||||
continue
|
||||
state_dict = load_state_dict(shard_file, is_quantized=is_quantized)
|
||||
map_location = None
|
||||
if (
|
||||
device_map is not None
|
||||
and hf_quantizer is not None
|
||||
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
|
||||
and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
|
||||
):
|
||||
map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
|
||||
state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
|
||||
|
||||
# Mistmatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
|
||||
# matching the weights in the model.
|
||||
|
||||
@@ -217,9 +217,8 @@ class HfQuantizer(ABC):
|
||||
@abstractmethod
|
||||
def _process_model_after_weight_loading(self, model, **kwargs): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def is_serializable(self): ...
|
||||
def is_serializable(self, safe_serialization=None): ...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
|
||||
@@ -93,6 +93,5 @@ class AqlmHfQuantizer(HfQuantizer):
|
||||
)
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@@ -106,8 +106,7 @@ class AwqQuantizer(HfQuantizer):
|
||||
|
||||
model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
# AWQ through auto-awq has been always serializable, except if the model is fused.
|
||||
if self.quantization_config.do_fuse:
|
||||
logger.warning("You cannot save an AWQ model that uses fused modules!")
|
||||
|
||||
@@ -320,11 +320,10 @@ class Bnb4BitHfQuantizer(HfQuantizer):
|
||||
# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
model.is_loaded_in_4bit = True
|
||||
model.is_4bit_serializable = self.is_serializable
|
||||
model.is_4bit_serializable = self.is_serializable()
|
||||
return model
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
_is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
|
||||
|
||||
if not _is_4bit_serializable:
|
||||
|
||||
@@ -210,7 +210,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
|
||||
|
||||
new_value = param_value.to("cpu")
|
||||
if self.pre_quantized and not self.is_serializable:
|
||||
if self.pre_quantized and not self.is_serializable():
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
@@ -238,7 +238,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
model.is_loaded_in_8bit = True
|
||||
model.is_8bit_serializable = self.is_serializable
|
||||
model.is_8bit_serializable = self.is_serializable()
|
||||
return model
|
||||
|
||||
def _process_model_before_weight_loading(
|
||||
@@ -282,8 +282,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
|
||||
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
_bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
|
||||
"0.37.2"
|
||||
)
|
||||
|
||||
@@ -72,6 +72,5 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
|
||||
def is_trainable(self):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return False
|
||||
|
||||
@@ -161,8 +161,7 @@ class EetqHfQuantizer(HfQuantizer):
|
||||
|
||||
model.config.quantization_config = self.quantization_config
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@property
|
||||
|
||||
@@ -196,8 +196,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
|
||||
not_missing_keys.append(missing)
|
||||
return [k for k in missing_keys if k not in not_missing_keys]
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@property
|
||||
|
||||
@@ -89,6 +89,5 @@ class GptqHfQuantizer(HfQuantizer):
|
||||
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return True
|
||||
|
||||
@@ -188,11 +188,10 @@ class HqqHfQuantizer(HfQuantizer):
|
||||
|
||||
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
|
||||
model.is_hqq_quantized = True
|
||||
model.is_hqq_serializable = self.is_serializable
|
||||
model.is_hqq_serializable = self.is_serializable()
|
||||
return model
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return False
|
||||
|
||||
@property
|
||||
|
||||
@@ -195,6 +195,5 @@ class QuantoHfQuantizer(HfQuantizer):
|
||||
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
|
||||
return False
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
return False
|
||||
|
||||
@@ -160,9 +160,19 @@ class TorchAoHfQuantizer(HfQuantizer):
|
||||
"""No process required for torchao quantized model"""
|
||||
return
|
||||
|
||||
@property
|
||||
def is_serializable(self):
|
||||
return False
|
||||
def is_serializable(self, safe_serialization=None):
|
||||
if safe_serialization:
|
||||
logger.warning(
|
||||
"torchao quantized model does not support safe serialization, "
|
||||
"please set `safe_serialization` to False"
|
||||
)
|
||||
return False
|
||||
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
|
||||
"0.25.0"
|
||||
)
|
||||
if not _is_torchao_serializable:
|
||||
logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
|
||||
return _is_torchao_serializable
|
||||
|
||||
@property
|
||||
def is_trainable(self):
|
||||
|
||||
@@ -95,7 +95,6 @@ class QuantizationConfigMixin:
|
||||
Returns:
|
||||
[`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
to_remove = []
|
||||
@@ -1235,24 +1234,11 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
self.quant_method = QuantizationMethod.TORCHAO
|
||||
self.quant_type = quant_type
|
||||
self.modules_to_not_convert = modules_to_not_convert
|
||||
self.kwargs = kwargs
|
||||
self._STR_TO_METHOD = {}
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import (
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
)
|
||||
|
||||
self._STR_TO_METHOD = {
|
||||
"int4_weight_only": int4_weight_only,
|
||||
"int8_weight_only": int8_weight_only,
|
||||
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||
}
|
||||
# when we load from serailized config, "quant_type_kwargs" will be the key
|
||||
if "quant_type_kwargs" in kwargs:
|
||||
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
self.quant_type_kwargs = kwargs
|
||||
|
||||
self.post_init()
|
||||
|
||||
@@ -1263,26 +1249,46 @@ class TorchAoConfig(QuantizationConfigMixin):
|
||||
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
|
||||
raise ValueError("Requires torchao 0.4.0 version and above")
|
||||
|
||||
if self.quant_type not in self._STR_TO_METHOD.keys():
|
||||
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
||||
if self.quant_type not in _STR_TO_METHOD.keys():
|
||||
raise ValueError(
|
||||
f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
|
||||
)
|
||||
|
||||
method = self._STR_TO_METHOD[self.quant_type]
|
||||
method = _STR_TO_METHOD[self.quant_type]
|
||||
sig = signature(method)
|
||||
all_kwargs = [
|
||||
param.name
|
||||
for param in sig.parameters.values()
|
||||
if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
|
||||
]
|
||||
for k in self.kwargs:
|
||||
for k in self.quant_type_kwargs:
|
||||
if k not in all_kwargs:
|
||||
raise ValueError(
|
||||
f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
|
||||
)
|
||||
|
||||
def _get_torchao_quant_type_to_method(self):
|
||||
if is_torchao_available():
|
||||
from torchao.quantization import (
|
||||
int4_weight_only,
|
||||
int8_dynamic_activation_int8_weight,
|
||||
int8_weight_only,
|
||||
)
|
||||
|
||||
return {
|
||||
"int4_weight_only": int4_weight_only,
|
||||
"int8_weight_only": int8_weight_only,
|
||||
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
|
||||
}
|
||||
else:
|
||||
raise ValueError(
|
||||
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
|
||||
)
|
||||
|
||||
def get_apply_tensor_subclass(self):
|
||||
return self._STR_TO_METHOD[self.quant_type](**self.kwargs)
|
||||
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
|
||||
return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
|
||||
return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.quant_type_kwargs.items())})"
|
||||
|
||||
Reference in New Issue
Block a user