Enable non-safetensor ser/deser for TorchAoConfig quantized model 🔴 (#33456)

* Enable non-safetensor serialization and deserialization for TorchAoConfig quantized model

Summary:
https://github.com/huggingface/huggingface_hub/pull/2440 added non-safetensor serialization and deserialization
to huggingface_hub. With that in place, we can now add the support in transformers.

Note that we don't plan to add safetensor serialization, because wrapper tensor subclasses and safetensors have different goals;
see the README for more details.

Test Plan:
tested locally

* formatting

* formatting

* minor fix

* formatting

* address comments

* comments

* minor fix

* update doc

* refactor compressed tensor quantizer
Jerry Zhang
2024-09-30 02:30:29 -07:00
committed by GitHub
parent 2e24ee4dfa
commit 4bb49d4e00
15 changed files with 134 additions and 64 deletions


@@ -11,7 +11,7 @@ rendered properly in your Markdown viewer.
 # TorchAO
-[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch, it provides high performance dtypes, optimization techniques and kernels for inference and training, featuring composability with native PyTorch features like `torch.compile`, FSDP etc.. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#without-intrusive-code-changes)
+[TorchAO](https://github.com/pytorch/ao) is an architecture optimization library for PyTorch. It provides high-performance dtypes, optimization techniques, and kernels for inference and training, and composes with native PyTorch features such as `torch.compile` and FSDP. Some benchmark numbers can be found [here](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks).
 Before you begin, make sure the following libraries are installed with their latest version:
@@ -21,6 +21,7 @@ pip install --upgrade torch torchao
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
@@ -40,6 +41,51 @@ quantized_model = torch.compile(quantized_model, mode="max-autotune")
 output = quantized_model.generate(**input_ids, max_new_tokens=10)
 print(tokenizer.decode(output[0], skip_special_tokens=True))
+# benchmark the performance
+import torch.utils.benchmark as benchmark
+def benchmark_fn(f, *args, **kwargs):
+    # Manual warmup
+    for _ in range(5):
+        f(*args, **kwargs)
+    t0 = benchmark.Timer(
+        stmt="f(*args, **kwargs)",
+        globals={"args": args, "kwargs": kwargs, "f": f},
+        num_threads=torch.get_num_threads(),
+    )
+    return f"{(t0.blocked_autorange().mean):.3f}"
+MAX_NEW_TOKENS = 1000
+print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
+bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
+bf16_model = torch.compile(bf16_model, mode="max-autotune")
+print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
 ```
-torchao quantization is implemented with tensor subclasses, currently it does not work with huggingface serialization, both the safetensor option and [non-safetensor option](https://github.com/huggingface/transformers/issues/32364), we'll update here with instructions when it's working.
+## Serialization and Deserialization
+torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), so it only works with Hugging Face non-safetensor serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution at load time and uses [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions.
+It does not support safetensors serialization because the two approaches have different goals. Wrapper tensor subclasses allow maximum flexibility, so that supporting a new quantized tensor format takes little effort. safetensors optimizes for maximum safety (no user code execution), which also means each new quantization format has to be supported manually.
+```py
+# save quantized model locally
+output_dir = "llama3-8b-int4wo-128"
+quantized_model.save_pretrained(output_dir, safe_serialization=False)
+# push to huggingface hub
+# save_to = "{user_id}/llama3-8b-int4wo-128"
+# quantized_model.push_to_hub(save_to, safe_serialization=False)
+# load quantized model
+ckpt_id = "llama3-8b-int4wo-128"  # or huggingface hub model id
+loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="cuda")
+# confirm the speedup
+loaded_quantized_model = torch.compile(loaded_quantized_model, mode="max-autotune")
+print("loaded int4wo-128 model:", benchmark_fn(loaded_quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS))
+```
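The allowlisting idea behind `torch.load(..., weights_only=True)` and `add_safe_globals` can be illustrated without torch. The sketch below is only an analogy built on the standard library's `pickle.Unpickler.find_class`; it is not how PyTorch implements the feature, and `AllowlistUnpickler`, `SAFE_GLOBALS`, and `safe_loads` are hypothetical names introduced for illustration.

```python
import io
import pickle
from collections import OrderedDict

# Allowlist of (module, name) pairs that may be resolved while loading.
# Hypothetical analogue of the globals registered via add_safe_globals.
SAFE_GLOBALS = set()


class AllowlistUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Refuse every global that has not been explicitly allowlisted.
        if (module, name) not in SAFE_GLOBALS:
            raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")
        return super().find_class(module, name)


def safe_loads(payload: bytes):
    return AllowlistUnpickler(io.BytesIO(payload)).load()


# An OrderedDict stands in for a checkpoint that references a non-primitive type.
payload = pickle.dumps(OrderedDict(weight=[1, 2, 3]))

try:
    safe_loads(payload)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# Allowlisting the type makes the same payload loadable.
SAFE_GLOBALS.add(("collections", "OrderedDict"))
restored = safe_loads(payload)
```

The design point is that the loader never resolves a name the caller did not approve, so adding a new tensor subclass only requires registering it rather than changing the file format.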


@@ -540,7 +540,11 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
     return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool = False):
+def load_state_dict(
+    checkpoint_file: Union[str, os.PathLike],
+    is_quantized: bool = False,
+    map_location: Optional[Union[str, torch.device]] = None,
+):
     """
     Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
     """
@@ -555,13 +559,18 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], is_quantized: bool
             )
         return safe_load_file(checkpoint_file)
     try:
-        if (
-            (is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0)
-            or (is_fsdp_enabled() and not is_local_dist_rank_0())
-        ) and not is_quantized:
-            map_location = "meta"
-        else:
-            map_location = "cpu"
+        if map_location is None:
+            if (
+                (
+                    is_deepspeed_zero3_enabled()
+                    and torch.distributed.is_initialized()
+                    and torch.distributed.get_rank() > 0
+                )
+                or (is_fsdp_enabled() and not is_local_dist_rank_0())
+            ) and not is_quantized:
+                map_location = "meta"
+            else:
+                map_location = "cpu"
         extra_args = {}
         # mmap can only be used with files serialized with zipfile-based format.
         if (
@@ -2564,7 +2573,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         hf_quantizer = getattr(self, "hf_quantizer", None)
         quantization_serializable = (
-            hf_quantizer is not None and isinstance(hf_quantizer, HfQuantizer) and hf_quantizer.is_serializable
+            hf_quantizer is not None
+            and isinstance(hf_quantizer, HfQuantizer)
+            and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
         )
         if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
@@ -4479,7 +4490,15 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                 # Skip the load for shards that only contain disk-offloaded weights when using safetensors for the offload.
                 if shard_file in disk_only_shard_files:
                     continue
-                state_dict = load_state_dict(shard_file, is_quantized=is_quantized)
+                map_location = None
+                if (
+                    device_map is not None
+                    and hf_quantizer is not None
+                    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
+                    and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
+                ):
+                    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
+                state_dict = load_state_dict(shard_file, is_quantized=is_quantized, map_location=map_location)
                 # Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
                 # matching the weights in the model.
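The device selection added in this hunk takes the first `device_map` entry that is neither `"cpu"` nor `"disk"`. A minimal sketch of that selection in plain Python follows; `pick_accelerator_device` is a hypothetical helper, not part of transformers, and returning `None` when no accelerator entry exists is an assumption of this sketch (the list-indexing expression in the diff would raise `IndexError` in that case).

```python
from typing import Dict, Optional, Union

DeviceSpec = Union[str, int]


def pick_accelerator_device(device_map: Optional[Dict[str, DeviceSpec]]) -> Optional[DeviceSpec]:
    """Return the first device_map value that is neither "cpu" nor "disk".

    Returns None when device_map is None or holds no such entry
    (assumed fallback for this sketch only).
    """
    if device_map is None:
        return None
    for device in device_map.values():
        if device not in ("cpu", "disk"):
            return device
    return None


# Example device maps: a single-device map and a mixed offload map.
single = pick_accelerator_device({"": "cuda"})
mixed = pick_accelerator_device({"model.embed_tokens": "cpu", "model.layers.0": 0})
offload_only = pick_accelerator_device({"model.embed_tokens": "cpu", "lm_head": "disk"})
```

With a result of `None`, a caller could keep the default `map_location` logic of `load_state_dict` instead of failing.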


@@ -217,9 +217,8 @@ class HfQuantizer(ABC):
     @abstractmethod
     def _process_model_after_weight_loading(self, model, **kwargs): ...
-    @property
     @abstractmethod
-    def is_serializable(self): ...
+    def is_serializable(self, safe_serialization=None): ...
     @property
     @abstractmethod
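Turning `is_serializable` from a property into a method is why the call sites in the other quantizers gain parentheses. The sketch below uses two hypothetical classes, `OldQuantizer` and `NewQuantizer`, to show what happens at a call site that is not updated: a bound method is always truthy, so a check such as `not self.is_serializable` would silently stop firing.

```python
class OldQuantizer:
    @property
    def is_serializable(self):
        return False


class NewQuantizer:
    def is_serializable(self, safe_serialization=None):
        return False


old, new = OldQuantizer(), NewQuantizer()

# Property: attribute access already yields the boolean.
old_result = bool(old.is_serializable)

# Method: attribute access yields a bound method object, which is always truthy.
stale_call_site = bool(new.is_serializable)

# Updated call site: the method is actually called.
updated_call_site = bool(new.is_serializable())

print(old_result, stale_call_site, updated_call_site)  # False True False
```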


@@ -93,6 +93,5 @@ class AqlmHfQuantizer(HfQuantizer):
             )
             return False
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return True


@@ -106,8 +106,7 @@ class AwqQuantizer(HfQuantizer):
             model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         # AWQ through auto-awq has been always serializable, except if the model is fused.
         if self.quantization_config.do_fuse:
             logger.warning("You cannot save an AWQ model that uses fused modules!")


@@ -320,11 +320,10 @@ class Bnb4BitHfQuantizer(HfQuantizer):
     # Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer._process_model_after_weight_loading with 8bit->4bit
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_4bit = True
-        model.is_4bit_serializable = self.is_serializable
+        model.is_4bit_serializable = self.is_serializable()
         return model
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         _is_4bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) >= version.parse("0.41.3")
         if not _is_4bit_serializable:


@@ -210,7 +210,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
             raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {target_device}.")
         new_value = param_value.to("cpu")
-        if self.pre_quantized and not self.is_serializable:
+        if self.pre_quantized and not self.is_serializable():
             raise ValueError(
                 "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
                 "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
@@ -238,7 +238,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_loaded_in_8bit = True
-        model.is_8bit_serializable = self.is_serializable
+        model.is_8bit_serializable = self.is_serializable()
         return model
     def _process_model_before_weight_loading(
@@ -282,8 +282,7 @@ class Bnb8BitHfQuantizer(HfQuantizer):
         model.config.quantization_config = self.quantization_config
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         _bnb_supports_8bit_serialization = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse(
             "0.37.2"
         )


@@ -72,6 +72,5 @@ class CompressedTensorsHfQuantizer(HfQuantizer):
     def is_trainable(self):
         return False
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return False


@@ -161,8 +161,7 @@ class EetqHfQuantizer(HfQuantizer):
         model.config.quantization_config = self.quantization_config
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return True
     @property


@@ -196,8 +196,7 @@ class FbgemmFp8HfQuantizer(HfQuantizer):
                         not_missing_keys.append(missing)
         return [k for k in missing_keys if k not in not_missing_keys]
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return True
     @property


@@ -89,6 +89,5 @@ class GptqHfQuantizer(HfQuantizer):
     def is_trainable(self, model: Optional["PreTrainedModel"] = None):
         return True
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return True


@@ -188,11 +188,10 @@ class HqqHfQuantizer(HfQuantizer):
     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
         model.is_hqq_quantized = True
-        model.is_hqq_serializable = self.is_serializable
+        model.is_hqq_serializable = self.is_serializable()
         return model
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return False
     @property


@@ -195,6 +195,5 @@ class QuantoHfQuantizer(HfQuantizer):
     def is_trainable(self, model: Optional["PreTrainedModel"] = None):
         return False
-    @property
-    def is_serializable(self):
+    def is_serializable(self, safe_serialization=None):
         return False


@@ -160,9 +160,19 @@ class TorchAoHfQuantizer(HfQuantizer):
"""No process required for torchao quantized model"""
return
-    @property
-    def is_serializable(self):
-        return False
+    def is_serializable(self, safe_serialization=None):
+        if safe_serialization:
+            logger.warning(
+                "torchao quantized model does not support safe serialization, "
+                "please set `safe_serialization` to False"
+            )
+            return False
+        _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
+            "0.25.0"
+        )
+        if not _is_torchao_serializable:
+            logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")
+        return _is_torchao_serializable
     @property
     def is_trainable(self):
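The decision made by the new `TorchAoHfQuantizer.is_serializable` can be summarized as two gates: reject safe serialization, then require `huggingface_hub >= 0.25.0`. The sketch below restates that logic as a pure function so it can be exercised without either library installed. `torchao_is_serializable` and `parse_version` are hypothetical helpers; the real code uses `packaging.version` and reads the installed version through `importlib.metadata`, while this sketch takes the version as a string argument and only handles plain `X.Y.Z` releases.

```python
def parse_version(text: str) -> tuple:
    """Parse a plain 'X.Y.Z' release string into a tuple of ints (no pre-release handling)."""
    return tuple(int(part) for part in text.split("."))


def torchao_is_serializable(safe_serialization, hub_version: str) -> bool:
    # Gate 1: safetensors output is never supported for torchao-quantized models.
    if safe_serialization:
        return False
    # Gate 2: non-safetensor serialization needs a recent enough huggingface_hub.
    return parse_version(hub_version) >= parse_version("0.25.0")


results = {
    "safe, new hub": torchao_is_serializable(True, "0.25.0"),
    "unsafe, new hub": torchao_is_serializable(False, "0.25.0"),
    "default, old hub": torchao_is_serializable(None, "0.24.7"),
}
```

Note that `safe_serialization=None` (the default in the signature) is treated like `False`, so callers that do not pass the argument only hit the version gate.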


@@ -95,7 +95,6 @@ class QuantizationConfigMixin:
         Returns:
             [`QuantizationConfigMixin`]: The configuration object instantiated from those parameters.
         """
         config = cls(**config_dict)
         to_remove = []
@@ -1235,24 +1234,11 @@ class TorchAoConfig(QuantizationConfigMixin):
         self.quant_method = QuantizationMethod.TORCHAO
         self.quant_type = quant_type
         self.modules_to_not_convert = modules_to_not_convert
-        self.kwargs = kwargs
-        self._STR_TO_METHOD = {}
-        if is_torchao_available():
-            from torchao.quantization import (
-                int4_weight_only,
-                int8_dynamic_activation_int8_weight,
-                int8_weight_only,
-            )
-            self._STR_TO_METHOD = {
-                "int4_weight_only": int4_weight_only,
-                "int8_weight_only": int8_weight_only,
-                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
-            }
+        # when we load from serialized config, "quant_type_kwargs" will be the key
+        if "quant_type_kwargs" in kwargs:
+            self.quant_type_kwargs = kwargs["quant_type_kwargs"]
         else:
-            raise ValueError(
-                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
-            )
+            self.quant_type_kwargs = kwargs
         self.post_init()
@@ -1263,26 +1249,46 @@ class TorchAoConfig(QuantizationConfigMixin):
         if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
             raise ValueError("Requires torchao 0.4.0 version and above")
-        if self.quant_type not in self._STR_TO_METHOD.keys():
+        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
+        if self.quant_type not in _STR_TO_METHOD.keys():
             raise ValueError(
                 f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
             )
-        method = self._STR_TO_METHOD[self.quant_type]
+        method = _STR_TO_METHOD[self.quant_type]
         sig = signature(method)
         all_kwargs = [
             param.name
             for param in sig.parameters.values()
             if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
         ]
-        for k in self.kwargs:
+        for k in self.quant_type_kwargs:
             if k not in all_kwargs:
                 raise ValueError(
                     f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
                 )
+    def _get_torchao_quant_type_to_method(self):
+        if is_torchao_available():
+            from torchao.quantization import (
+                int4_weight_only,
+                int8_dynamic_activation_int8_weight,
+                int8_weight_only,
+            )
+            return {
+                "int4_weight_only": int4_weight_only,
+                "int8_weight_only": int8_weight_only,
+                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+            }
+        else:
+            raise ValueError(
+                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+            )
     def get_apply_tensor_subclass(self):
-        return self._STR_TO_METHOD[self.quant_type](**self.kwargs)
+        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
+        return _STR_TO_METHOD[self.quant_type](**self.quant_type_kwargs)
     def __repr__(self):
-        return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.kwargs.items())})"
+        return f"{self.quant_type}({', '.join(str(k) + '=' + str(v) for k, v in self.quant_type_kwargs.items())})"
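Two ideas in this hunk can be tried in isolation: accepting the quantization arguments either directly or nested under `"quant_type_kwargs"` (the shape a reloaded config has), and validating them against the target function's signature. The sketch below is a simplified illustration, not the `TorchAoConfig` code itself; `fake_int4_weight_only`, `resolve_quant_type_kwargs`, and `validate_kwargs` are hypothetical names, and the stand-in function's parameters are chosen only for the example.

```python
from inspect import Parameter, signature


def fake_int4_weight_only(group_size=128, inner_k_tiles=8):
    """Hypothetical stand-in for a torchao quantization API."""
    return ("int4_weight_only", group_size, inner_k_tiles)


def resolve_quant_type_kwargs(kwargs: dict) -> dict:
    # A serialized config nests the arguments under "quant_type_kwargs";
    # a freshly constructed config passes them directly.
    return kwargs["quant_type_kwargs"] if "quant_type_kwargs" in kwargs else kwargs


def validate_kwargs(method, kwargs: dict) -> None:
    # Collect every parameter name that can be passed by keyword.
    accepted = [
        param.name
        for param in signature(method).parameters.values()
        if param.kind in (Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD)
    ]
    for key in kwargs:
        if key not in accepted:
            raise ValueError(f"Unexpected keyword arg: {key}, accepted keyword args are: {accepted}")


fresh = resolve_quant_type_kwargs({"group_size": 64})
reloaded = resolve_quant_type_kwargs({"quant_type_kwargs": {"group_size": 64}})
validate_kwargs(fake_int4_weight_only, fresh)  # passes silently
```

Both construction paths end up with the same dictionary, which is what lets a saved config round-trip through `from_pretrained`.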